MLIR 22.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
1//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/Sequence.h"
31
32#include <type_traits>
33
34using namespace mlir;
35using namespace mlir::tosa;
36
37// Helper function to materialize the semantically correct compare and select
38// operations given a binary operation with a specific NaN propagation mode.
39//
40// In the case of "PROPAGATE" semantics no compare and selection is required and
41// this function does nothing.
42//
43// In the case of "IGNORE" semantics this function materializes a comparison of
44// the current operands to the op which will return true for any NaN
45// argument and then selects between the non-NaN operation argument and the
46// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
47// code:
48//
49// In the case that the op is operating on non floating point types we ignore
50// the attribute completely, this is consistent with the TOSA spec which has
51// the following wording: "This attribute is ignored by non floating-point
52// types."
53//
54// binary<op>(lhs, rhs):
55// result = op(lhs, rhs)
56// if lhs == NaN return rhs
57// if rhs == NaN return lhs
58// return result
59template <typename OpTy>
60static Value
63 // NaN propagation has no meaning for non floating point types.
64 if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
65 return result;
66
67 auto nanMode = op.getNanMode();
68 if (nanMode == NanPropagationMode::PROPAGATE)
69 return result;
70
71 // Unordered comparison of NaN against itself will always return true.
72 Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
73 arith::CmpFPredicate::UNO, lhs, lhs);
74 Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
75 arith::CmpFPredicate::UNO, rhs, rhs);
76 Value rhsOrResult =
77 arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
78 return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
79 rhsOrResult);
80}
81
83 Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
84 ConversionPatternRewriter &rewriter) {
85 Location loc = op->getLoc();
86 auto elementTy =
87 cast<ShapedType>(op->getOperand(0).getType()).getElementType();
88
89 // tosa::AbsOp
90 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
91 return math::AbsFOp::create(rewriter, loc, resultTypes, args);
92
93 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
94 auto zero = arith::ConstantOp::create(rewriter, loc,
95 rewriter.getZeroAttr(elementTy));
96 auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
97 return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
98 }
99
100 // tosa::AddOp
101 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
102 return arith::AddFOp::create(rewriter, loc, resultTypes, args);
103
104 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
105 return arith::AddIOp::create(rewriter, loc, resultTypes, args);
106
107 // tosa::SubOp
108 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
109 return arith::SubFOp::create(rewriter, loc, resultTypes, args);
110
111 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
112 return arith::SubIOp::create(rewriter, loc, resultTypes, args);
113
114 // tosa::IntDivOp
115 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
116 return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
117
118 // tosa::ReciprocalOp
119 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
120 auto one =
121 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
122 return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
123 }
124
125 // tosa::MulOp
126 if (isa<tosa::MulOp>(op)) {
127 auto shiftVal = cast<tosa::MulOp>(op).getShift();
128 DenseElementsAttr shiftElem;
129 bool shiftIsConstant = true;
130 int32_t shift = 0;
131 if (matchPattern(shiftVal, m_Constant(&shiftElem)))
132 shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
133 else
134 shiftIsConstant = false;
135
136 if (isa<FloatType>(elementTy)) {
137 if (shift != 0) {
138 (void)rewriter.notifyMatchFailure(op,
139 "Cannot have shift value for float");
140 return nullptr;
141 }
142 return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
143 args[1]);
144 }
145
146 if (isa<IntegerType>(elementTy)) {
147 Value a = args[0];
148 Value b = args[1];
149
150 if (shift > 0 || !shiftIsConstant) {
151 Value shiftConst;
152 if (shiftIsConstant)
153 shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
154 /*bitwidth=*/8);
155
156 if (!a.getType().isInteger(32))
157 a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
158
159 if (!b.getType().isInteger(32))
160 b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
161
162 auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
163 auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
164 RoundingMode::SINGLE_ROUND);
165 auto result =
166 tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
167 b, shiftAmount, roundingAttr);
168
169 return result;
170 }
171
172 int aWidth = a.getType().getIntOrFloatBitWidth();
173 int bWidth = b.getType().getIntOrFloatBitWidth();
174 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
175
176 if (aWidth < cWidth)
177 a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
178 if (bWidth < cWidth)
179 b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
180
181 return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
182 }
183 }
184
185 // tosa::NegateOp
186 if (isa<tosa::NegateOp>(op)) {
187 auto negate = cast<tosa::NegateOp>(op);
188
189 int64_t inZp = 0, outZp = 0;
190 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
191 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
192 bool hasInZp = !failed(maybeInZp);
193 bool hasOutZp = !failed(maybeOutZp);
194 if (hasInZp)
195 inZp = *maybeInZp;
196 if (hasOutZp)
197 outZp = *maybeOutZp;
198
199 if (isa<FloatType>(elementTy))
200 return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
201
202 if (isa<IntegerType>(elementTy)) {
203 if (hasInZp && hasOutZp && !inZp && !outZp) {
204 auto constant = arith::ConstantOp::create(
205 rewriter, loc, IntegerAttr::get(elementTy, 0));
206 return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
207 args[0]);
208 }
209
210 Value zpAddValue;
211 Type intermediateType;
212 // Compute the maximum value that can occur in the intermediate buffer.
213 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
214 int intermediateBitWidth = 64;
215
216 if (hasInZp && hasOutZp) {
217 // Compute the maximum value that can occur in the intermediate buffer.
218 const int64_t zpAdd = inZp + outZp;
219 const int64_t maxValue =
220 APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
221 std::abs(zpAdd) + 1;
222
223 // Convert that maximum value into the maximum bitwidth needed to
224 // represent it. We assume 48-bit numbers may be supported further in
225 // the pipeline.
226 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
227 intermediateBitWidth = 16;
228 } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
229 intermediateBitWidth = 32;
230 } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
231 intermediateBitWidth = 48;
232 }
233
234 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
235 zpAddValue = arith::ConstantOp::create(
236 rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
237 } else {
238 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
239 auto arg1 =
240 arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
241 auto arg2 =
242 arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
243 zpAddValue =
244 arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
245 }
246
247 // The negation can be applied by doing:
248 // outputValue = inZp + outZp - inputValue
249 auto ext =
250 arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
251 auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
252
253 // Clamp to the negation range.
255 rewriter, loc, intermediateType,
256 APInt::getSignedMinValue(inputBitWidth).getSExtValue());
258 rewriter, loc, intermediateType,
259 APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
260 auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
261
262 // Truncate to the final value.
263 return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
264 }
265 }
266
267 // tosa::BitwiseAndOp
268 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
269 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
270
271 // tosa::BitwiseOrOp
272 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
273 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
274
275 // tosa::BitwiseNotOp
276 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
277 auto allOnesAttr = rewriter.getIntegerAttr(
278 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
279 auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
280 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
281 }
282
283 // tosa::BitwiseXOrOp
284 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
285 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
286
287 // tosa::LogicalLeftShiftOp
288 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
289 return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
290
291 // tosa::LogicalRightShiftOp
292 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
293 return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
294
295 // tosa::ArithmeticRightShiftOp
296 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
297 auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
298 auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
299 if (!round) {
300 return result;
301 }
302
303 Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
304 auto one = arith::ConstantOp::create(rewriter, loc,
305 IntegerAttr::get(elementTy, 1));
306 auto zero = arith::ConstantOp::create(rewriter, loc,
307 IntegerAttr::get(elementTy, 0));
308 auto i1zero =
309 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
310 auto i1one =
311 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
312
313 // Checking that input2 != 0
314 auto shiftValueGreaterThanZero = arith::CmpIOp::create(
315 rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
316
317 // Checking for the last bit of input1 to be 1
318 auto subtract =
319 arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
320 auto shifted =
321 arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
322 ->getResults();
323 auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
325 auto isInputOdd =
326 arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
327 // shifted, truncated, isInputOdd can be poison when input2 is 0.
328 auto shouldRound = arith::SelectOp::create(
329 rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
330 auto extended =
331 arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
332 return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
333 }
334
335 // tosa::ClzOp
336 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
337 return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
338 }
339
340 // tosa::LogicalAnd
341 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
342 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
343
344 // tosa::LogicalNot
345 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
346 auto one = arith::ConstantOp::create(rewriter, loc,
347 rewriter.getIntegerAttr(elementTy, 1));
348 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
349 }
350
351 // tosa::LogicalOr
352 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
353 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
354
355 // tosa::LogicalXor
356 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
357 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
358
359 // tosa::PowOp
360 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
361 return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
362
363 // tosa::RsqrtOp
364 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
365 return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
366
367 // tosa::LogOp
368 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
369 return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
370
371 // tosa::ExpOp
372 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
373 return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
374
375 // tosa::SinOp
376 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
377 return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
378
379 // tosa::CosOp
380 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
381 return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
382
383 // tosa::TanhOp
384 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
385 return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
386
387 // tosa::ErfOp
388 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
389 return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
390
391 // tosa::GreaterOp
392 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
393 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
394 args[0], args[1]);
395
396 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
397 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
398 args[0], args[1]);
399
400 // tosa::GreaterEqualOp
401 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
402 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
403 args[0], args[1]);
404
405 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
406 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
407 args[0], args[1]);
408
409 // tosa::EqualOp
410 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
411 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
412 args[0], args[1]);
413
414 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
415 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
416 args[0], args[1]);
417
418 // tosa::SelectOp
419 if (isa<tosa::SelectOp>(op)) {
420 elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
421 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
422 return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
423 }
424
425 // tosa::MaximumOp
426 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
427 auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
428 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
429 rewriter, args[0], args[1], max);
430 }
431
432 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
433 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
434 }
435
436 // tosa::MinimumOp
437 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
438 auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
439 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
440 rewriter, args[0], args[1], min);
441 }
442
443 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
444 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
445 }
446
447 // tosa::CeilOp
448 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
449 return math::CeilOp::create(rewriter, loc, resultTypes, args);
450
451 // tosa::FloorOp
452 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
453 return math::FloorOp::create(rewriter, loc, resultTypes, args);
454
455 // tosa::ClampOp
456 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
457 bool losesInfo = false;
458 APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
459 APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
460 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
461 APFloat::rmNearestTiesToEven, &losesInfo);
462 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
463 APFloat::rmNearestTiesToEven, &losesInfo);
464 auto min = arith::ConstantOp::create(
465 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
466 auto max = arith::ConstantOp::create(
467 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
468 auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
469
470 auto clampOp = llvm::cast<tosa::ClampOp>(op);
471 const auto nanMode = clampOp.getNanMode();
472
473 // NaN propagation has no meaning for non floating point types.
474 if (!isa<FloatType>(elementTy))
475 return result;
476
477 // In the case of "PROPAGATE" semantics no compare and selection is
478 // required.
479 if (nanMode == NanPropagationMode::PROPAGATE)
480 return result;
481
482 // In the case of "IGNORE" semantics materialize a comparison
483 // of the current operand to the reduction which will return true for a NaN
484 // argument and then selects between the initial reduction value and the
485 // calculated result based on whether the argument is NaN or not. In pseudo
486 // code:
487 //
488 // reduce<op>(x, init):
489 // result = op(init, x)
490 // return init if x == NaN else result
491
492 // Unordered comparison of NaN against itself will always return true.
493 Value isNaN = arith::CmpFOp::create(
494 rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
495 // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
496 // is NaN.
497 return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
498 }
499
500 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
501 auto intTy = cast<IntegerType>(elementTy);
502 int64_t min =
503 cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
504 int64_t max =
505 cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
506
507 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
508 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
509 if (intTy.isUnsignedInteger()) {
510 minRepresentable = 0;
511 if (intTy.getIntOrFloatBitWidth() <= 63) {
512 maxRepresentable =
513 (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
514 .getZExtValue();
515 }
516 } else if (intTy.getIntOrFloatBitWidth() <= 64) {
517 // Ensure that min & max fit into signed n-bit constants.
518 minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
519 .getSExtValue();
520 maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
521 .getSExtValue();
522 }
523 // Ensure that the bounds are representable as n-bit signed/unsigned
524 // integers.
525 min = std::max(min, minRepresentable);
526 max = std::max(max, minRepresentable);
527 min = std::min(min, maxRepresentable);
528 max = std::min(max, maxRepresentable);
529
530 auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
531 intTy.getIntOrFloatBitWidth());
532 auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
533 intTy.getIntOrFloatBitWidth());
534 return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
535 intTy.isUnsignedInteger());
536 }
537
538 // tosa::SigmoidOp
539 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
540 auto one =
541 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
542 auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
543 auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
544 auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
545 return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
546 }
547
548 // tosa::CastOp
549 if (isa<tosa::CastOp>(op)) {
550 Type srcTy = elementTy;
551 Type dstTy = resultTypes.front();
552 if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
553 (void)rewriter.notifyMatchFailure(op, "unsupported type");
554 return nullptr;
555 }
556
557 bool bitExtend =
559
560 if (srcTy == dstTy)
561 return args.front();
562
563 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
564 return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
566
567 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
568 return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
570
571 // 1-bit integers need to be treated as signless.
572 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
573 return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
575
576 if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
577 return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
579
580 // Unsigned integers need an unrealized cast so that they can be passed
581 // to UIToFP.
582 if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
583 auto unrealizedCast =
584 UnrealizedConversionCastOp::create(
585 rewriter, loc,
586 rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
587 .getResult(0);
588 return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
589 unrealizedCast);
590 }
591
592 // All other si-to-fp conversions should be handled by SIToFP.
593 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
594 return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
596
597 // Casting to boolean, floats need to only be checked as not-equal to zero.
598 if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
599 Value zero = arith::ConstantOp::create(rewriter, loc,
600 rewriter.getFloatAttr(srcTy, 0.0));
601 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
602 args.front(), zero);
603 }
604
605 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
606 auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
607
608 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
609 // Check whether neither int min nor int max can be represented in the
610 // input floating-point type due to too short exponent range.
611 if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
612 APFloat::semanticsMaxExponent(fltSemantics)) {
613 // Use cmp + select to replace infinites by int min / int max. Other
614 // integral values can be represented in the integer space.
615 auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
616 auto posInf = arith::ConstantOp::create(
617 rewriter, loc,
618 rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
619 APFloat::getInf(fltSemantics)));
620 auto negInf = arith::ConstantOp::create(
621 rewriter, loc,
622 rewriter.getFloatAttr(
624 APFloat::getInf(fltSemantics, /*Negative=*/true)));
625 auto overflow = arith::CmpFOp::create(
626 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
627 auto underflow = arith::CmpFOp::create(
628 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
629 auto intMin = arith::ConstantOp::create(
630 rewriter, loc,
631 rewriter.getIntegerAttr(
633 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
634 auto intMax = arith::ConstantOp::create(
635 rewriter, loc,
636 rewriter.getIntegerAttr(
638 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
639 auto maxClamped =
640 arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
641 return arith::SelectOp::create(rewriter, loc, underflow, intMin,
642 maxClamped);
643 }
644
645 auto intMinFP = arith::ConstantOp::create(
646 rewriter, loc,
647 rewriter.getFloatAttr(
649 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
650 .getSExtValue()));
651
652 // Check whether the mantissa has enough bits to represent int max.
653 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
654 dstTy.getIntOrFloatBitWidth() - 1) {
655 // Int min can also be represented since it is a power of two and thus
656 // consists of a single leading bit. Therefore we can clamp the input
657 // in the floating-point domain.
658
659 auto intMaxFP = arith::ConstantOp::create(
660 rewriter, loc,
661 rewriter.getFloatAttr(
663 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
664 .getSExtValue()));
665
666 Value clamped =
667 clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
668 return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
669 }
670
671 // Due to earlier check we know exponant range is big enough to represent
672 // int min. We can therefore rely on int max + 1 being representable as
673 // well because it's just int min with a positive sign. So clamp the min
674 // value and compare against that to select the max int value if needed.
675 auto intMaxPlusOneFP = arith::ConstantOp::create(
676 rewriter, loc,
677 rewriter.getFloatAttr(
679 static_cast<double>(
680 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
681 .getSExtValue()) +
682 1.0f));
683
684 auto intMax = arith::ConstantOp::create(
685 rewriter, loc,
686 rewriter.getIntegerAttr(
688 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
689 auto minClampedFP =
690 arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
691 auto minClamped =
692 arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
693 auto overflow = arith::CmpFOp::create(
694 rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
695 return arith::SelectOp::create(rewriter, loc, overflow, intMax,
696 minClamped);
697 }
698
699 // Casting to boolean, integers need to only be checked as not-equal to
700 // zero.
701 if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
702 Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
703 srcTy.getIntOrFloatBitWidth());
704 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
705 args.front(), zero);
706 }
707
708 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
709 return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
711
712 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
713 return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
714 }
715 }
716
717 (void)rewriter.notifyMatchFailure(
718 op, "unhandled op for linalg body calculation for elementwise op");
719 return nullptr;
720}
721
723
724// Emit an 'arith.constant' op for the given index if it has not been created
725// yet, or return an existing constant. This will prevent an excessive creation
726// of redundant constants, easing readability of emitted code for unit tests.
728 IndexPool &indexPool, int64_t index) {
729 auto [it, inserted] = indexPool.try_emplace(index);
730 if (inserted)
731 it->second =
732 arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
733 return it->second;
734}
735
737 IndexPool &indexPool, Value tensor, int64_t index) {
738 auto indexValue = createIndex(rewriter, loc, indexPool, index);
739 return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
740}
741
743 IndexPool &indexPool, Value tensor,
744 int64_t index) {
745 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
746 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
747 assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
748 if (shapedType.isDynamicDim(index))
749 return getTensorDim(rewriter, loc, indexPool, tensor, index);
750 return rewriter.getIndexAttr(shapedType.getDimSize(index));
751}
752
// Returns true when every operand and every result of `operation` is a
// ranked tensor (RankedTensorType); any unranked or non-tensor value
// makes the check fail.
static bool operandsAndResultsRanked(Operation *operation) {
  const auto hasRankedTensorType = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  if (!llvm::all_of(operation->getOperands(), hasRankedTensorType))
    return false;
  return llvm::all_of(operation->getResults(), hasRankedTensorType);
}
760
761// Compute the runtime dimension size for dimension 'dim' of the output by
762// inspecting input 'operands', all of which are expected to have the same rank.
763// This function returns a pair {targetSize, masterOperand}.
764//
765// The runtime size of the output dimension is returned either as a statically
766// computed attribute or as a runtime SSA value.
767//
768// If the target size was inferred directly from one dominating operand, that
769// operand is returned in 'masterOperand'. If the target size is inferred from
770// multiple operands, 'masterOperand' is set to nullptr.
771static std::pair<OpFoldResult, Value>
773 ValueRange operands, int64_t dim) {
774 // If any input operand contains a static size greater than 1 for this
775 // dimension, that is the target size. An occurrence of an additional static
776 // dimension greater than 1 with a different value is undefined behavior.
777 for (auto operand : operands) {
778 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
779 if (ShapedType::isStatic(size) && size > 1)
780 return {rewriter.getIndexAttr(size), operand};
781 }
782
783 // Filter operands with dynamic dimension
784 auto operandsWithDynamicDim =
785 llvm::filter_to_vector(operands, [&](Value operand) {
786 return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
787 });
788
789 // If no operand has a dynamic dimension, it means all sizes were 1
790 if (operandsWithDynamicDim.empty())
791 return {rewriter.getIndexAttr(1), operands.front()};
792
793 // Emit code that computes the runtime size for this dimension. If there is
794 // only one operand with a dynamic dimension, it is considered the master
795 // operand that determines the runtime size of the output dimension.
796 auto targetSize =
797 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
798 if (operandsWithDynamicDim.size() == 1)
799 return {targetSize, operandsWithDynamicDim[0]};
800
801 // Calculate maximum size among all dynamic dimensions
802 for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
803 auto nextSize =
804 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
805 targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
806 }
807 return {targetSize, nullptr};
808}
809
810// Compute the runtime output size for all dimensions. This function returns
811// a pair {targetShape, masterOperands}.
812static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
814 IndexPool &indexPool, ValueRange operands) {
815 assert(!operands.empty());
816 auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
817 SmallVector<OpFoldResult> targetShape;
818 SmallVector<Value> masterOperands;
819 for (auto dim : llvm::seq<int64_t>(0, rank)) {
820 auto [targetSize, masterOperand] =
821 computeTargetSize(rewriter, loc, indexPool, operands, dim);
822 targetShape.push_back(targetSize);
823 masterOperands.push_back(masterOperand);
824 }
825 return {targetShape, masterOperands};
826}
827
// Broadcast dimension 'dim' of 'operand' up to 'targetSize' when its runtime
// size turns out to be 1; otherwise the operand is returned unchanged. The
// decision is deferred to runtime via an 'scf.if'.
//
// NOTE(review): the declarator line naming this function (presumably
// 'static Value broadcastDynamicDimension(ConversionPatternRewriter
// &rewriter, Location loc,') was lost in extraction -- confirm against the
// original source.
                                      IndexPool &indexPool, Value operand,
                                      int64_t dim, OpFoldResult targetSize,
                                      Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op: the input map pins 'dim' to constant
  // 0 (reading the single element of the size-1 dimension); the output map is
  // the identity.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary: broadcast only when the runtime size of
  // this dimension equals 1.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op that copies the lone element across the
    // broadcast dimension.
    auto resultTensor =
        linalg::GenericOp::create(
            opBuilder, loc, outputTensor.getType(), operand, outputTensor,
            affineMaps, getNParallelLoopsAttrs(rank),
            [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
              // Emit 'linalg.yield' op
              linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
            })
            .getResult(0);

    // Cast to original operand type if necessary
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if': no broadcast needed, yield unchanged.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
909
// Broadcast every dynamic dimension of 'operand' to the corresponding entry
// of 'targetShape', one dimension at a time.
//
// NOTE(review): the declarator line naming this function (presumably
// 'static Value broadcastDynamicDimensions(ConversionPatternRewriter
// &rewriter, Location loc,') was lost in extraction -- confirm against the
// original source.
                                       IndexPool &indexPool, Value operand,
                                       ArrayRef<OpFoldResult> targetShape,
                                       ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  // Fold the per-dimension broadcast over the operand.
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
923
// Broadcast the dynamic dimensions of every operand to 'targetShape'.
// Returns the operands unchanged when broadcasting cannot apply (single
// operand, or all shapes fully static).
//
// NOTE(review): the return-type and declarator lines (presumably
// 'static SmallVector<Value> broadcastDynamicDimensions(...)') were lost in
// extraction -- confirm against the original source.
                                      IndexPool &indexPool, ValueRange operands,
                                      ArrayRef<OpFoldResult> targetShape,
                                      ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations
  if (operands.size() == 1)
    return operands;

  // No need to broadcast when every operand's shape is fully static.
  bool hasDynamic = false;
  for (auto op : operands) {
    const auto tType = dyn_cast<RankedTensorType>(op.getType());
    if (tType && !tType.hasStaticShape()) {
      hasDynamic = true;
      break;
    }
  }
  if (!hasDynamic)
    return operands;

  // Broadcast dynamic dimensions operand by operand
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
951
// Emit the linalg.generic that performs the elementwise computation for
// 'operation' on the (already broadcast) 'operands', writing into a new
// tensor.empty of 'targetShape'. The result is cast back to the converted
// result type and replaces 'operation'.
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  //
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no broadcasting
      // needed) because some transforms (like reshape folding)
      // do not support affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op
  // NOTE(review): two lines were lost in extraction below -- the iterator
  // types argument of GenericOp::create (presumably
  // 'getNParallelLoopsAttrs(rank),') and the call assigning 'opResult'
  // (presumably 'auto opResult = createLinalgBodyCalculationForElementwiseOp(').
  // Confirm against the original source.
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        // A null result means the scalar body could not be built.
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
1013
// Returns the prefix of 'operands' that should participate in broadcasting.
// Some trailing operands (tosa.mul shift, tosa.negate zero points) must not
// be broadcast when they are constants.
//
// NOTE(review): the declarator line naming this function (presumably
// 'static ValueRange getBroadcastableOperands(Operation *operation,') was
// lost in extraction -- confirm against the original source.
                                 ValueRange operands) {
  // Shift cannot broadcast
  if (isa<tosa::MulOp>(operation)) {
    DenseElementsAttr shiftElems;
    // Shift cannot broadcast when it is constant
    if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
      return operands.take_front(2);
    else
      return operands.take_front(3);
  }
  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    // Only when BOTH zero points are non-constant do all operands remain
    // broadcastable.
    if (failed(maybeOutZp) && failed(maybeInZp))
      return operands;
    // Input1_zp and output_zp cannot broadcast when they are constants.
    return operands.take_front(1);
  }
  return operands;
}
1035
// Common lowering driver for elementwise TOSA ops: compute the broadcast
// target shape, broadcast the relevant operands, then emit the
// linalg.generic computation.
//
// NOTE(review): the declarator line naming this function (presumably
// 'elementwiseMatchAndRewrite(Operation *operation, ValueRange operands,')
// was lost in extraction -- confirm against the original source.
static LogicalResult
                          ConversionPatternRewriter &rewriter,
                          const TypeConverter &converter) {

  // Collect op properties
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                    targetShape, converter);
}
1061
1062// Returns the constant initial value for a given reduction operation. The
1063// attribute type varies depending on the element type required.
1064static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1065 PatternRewriter &rewriter) {
1066 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1067 return rewriter.getFloatAttr(elementTy, 0.0);
1068
1069 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1070 return rewriter.getIntegerAttr(elementTy, 0);
1071
1072 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1073 return rewriter.getFloatAttr(elementTy, 1.0);
1074
1075 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1076 return rewriter.getIntegerAttr(elementTy, 1);
1077
1078 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1079 return rewriter.getFloatAttr(
1080 elementTy, APFloat::getLargest(
1081 cast<FloatType>(elementTy).getFloatSemantics(), false));
1082
1083 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1084 return rewriter.getIntegerAttr(
1085 elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1086
1087 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1088 return rewriter.getFloatAttr(
1089 elementTy, APFloat::getLargest(
1090 cast<FloatType>(elementTy).getFloatSemantics(), true));
1091
1092 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1093 return rewriter.getIntegerAttr(
1094 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1095
1096 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1097 return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1098
1099 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1100 return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1101
1102 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1103 return rewriter.getFloatAttr(
1104 elementTy, APFloat::getLargest(
1105 cast<FloatType>(elementTy).getFloatSemantics(), true));
1106
1107 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1108 return rewriter.getIntegerAttr(
1109 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1110
1111 return {};
1112}
1113
// Creates the body calculation for a reduction. The operations vary depending
// on the input type. Returns a null Value for unsupported (op, element type)
// combinations.
//
// NOTE(review): the first declarator line (presumably
// 'static Value createLinalgBodyCalculationForReduceOp(Operation *op,') was
// lost in extraction -- confirm against the original source.
                                                  ValueRange args,
                                                  Type elementTy,
                                                  PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}
1161
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Figure out the accType if needed: bf16 reduce_sum accumulates in f32 and
  // is truncated back to bf16 at the end (see 'widenAccTy' block below).
  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
                    isa<FloatType>(elementTy) &&
                    cast<FloatType>(elementTy).isBF16();
  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

  // The reduced shape drops the reduction axis; collect the runtime sizes of
  // the remaining dynamic dimensions for tensor.empty.
  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
  auto filledTensor =
      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                             ValueRange{emptyTensor})
          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) &&
        op.getNanMode() == NanPropagationMode::IGNORE) {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result be NaN iff all elements in
      // the reduction are NaN we can't simply perform a compare and select.
      // Additionally we have to keep track of whether we've seen any non-NaN
      // values and then do a final select based on this predicate.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
      auto emptyBoolTensor =
          tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                                 ValueRange{emptyBoolTensor})
              .result();
      // Note that because the linalg::ReduceOp has two variadic arguments
      // (inputs and outputs) and it has the SameVariadicOperandSize trait we
      // need to have the same number of inputs and outputs.
      //
      // The second input isn't actually used anywhere since the value used to
      // update the NaN flag is calculated inside the body of the reduction and
      // then used to update an out value.
      // In order to satisfy type constraints we just pass another copy of the
      // input here.
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  // NOTE: despite its name, 'didEncounterError' is set to true when the body
  // WAS successfully created; it remaining false is what triggers the match
  // failure below.
  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
      rewriter, loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};

        // If reduction type differs then extend (applicable to reduce_sum)
        if (binaryArgs[0].getType() != accTy)
          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
                                                binaryArgs[0]);

        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
                                                             accTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself will always return true.
          Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
                                              arith::CmpFPredicate::UNO,
                                              inputValue, inputValue);
          // If we've encountered a NaN, take the non-NaN value.
          auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
                                                  isNaN, initialValue, result);
          // Update the flag which keeps track of whether we have seen a non-NaN
          // value.
          auto newAllResultsNanFlagValue = arith::AndIOp::create(
              nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // Materialize a check to see whether we encountered any non-NaN values, if
    // we didn't we need to select a tensor of NaNs since the result will just
    // be the initial identity value propagated through all the compares and
    // selects inside the reduction.

    // Create a tensor full of NaNs.
    auto nanValueAttr = rewriter.getFloatAttr(
        accTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();
    auto nanFilledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();

    // Create an empty tensor, no need to fill this since it will be
    // overwritten by the select.
    auto finalEmptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();

    // Do a selection between the tensors akin to:
    // result = NaN if "all results NaN" else result.
    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  // Truncate back to resultTy if needed
  Value reducedRes = linalgOp->getResult(0);
  if (widenAccTy) {
    auto resEmptyOp =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
            .getResult();

    const unsigned reducedRank =
        cast<ShapedType>(reducedRes.getType()).getRank();
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    reducedRes =
        linalg::GenericOp::create(
            rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
            ValueRange{resEmptyOp},
            ArrayRef<AffineMap>{identityMap, identityMap},
            getNParallelLoopsAttrs(reducedRank),
            [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
              Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
                                                     elementTy, args[0]);
              linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
            })
            .getResults()[0];
  }

  // Build the reassociation that re-introduces the dropped reduction axis as
  // a size-1 dimension of the expanded result.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information. This matters when handling dynamically
  // sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
                                                     reassociationMap);
  return success();
}
1378
1379namespace {
1380
// Conversion pattern for elementwise (pointwise) TOSA ops; dispatches to the
// shared elementwise lowering helper.
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  // NOTE(review): the line invoking the helper (presumably
  // 'return elementwiseMatchAndRewrite(') was lost in extraction -- confirm
  // against the original source.
  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
1394
// Collapse tensor<1xiN> into tensor<iN>
// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
                                  Location loc) {
  // NOTE(review): the declaration of 'reassociation' (presumably an empty
  // 'SmallVector<ReassociationExprs, 1> reassociation;') was lost in
  // extraction -- confirm against the original source.
  // Create the collapsed type
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elemType = inputType.getElementType();
  auto collapsedType = RankedTensorType::get({}, elemType);
  // Emit the collapse op
  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
                                         reassociation);
}
1408
// Truncates each 32-bit element of 'input' to 8 bits via static_cast.
//
// NOTE(review): the return-type line (presumably
// 'static llvm::SmallVector<int8_t>') and the declaration of 'output'
// (presumably 'llvm::SmallVector<int8_t> output;') were lost in extraction
// -- confirm against the original source.
convertToI8(const llvm::SmallVector<int32_t> &input) {
  output.reserve(input.size());

  for (auto v : llvm::map_range(
           input, [](int32_t val) { return static_cast<int8_t>(val); })) {
    output.push_back(v);
  }
  return output;
}
1420
// The shift or multiplier may be either constant or non-constant, depending on
// whether dynamic extension is enabled.
// - If the shift or multiplier is non-constant, add it as an input to
//   linalg::GenericOp by:
//   1. Pushing it into 'genericInputs'.
//   2. Appending a corresponding affine map to 'indexingMaps'.
// - If the shift or multiplier is constant, set 'constant' instead.
// On return, 'arg' holds the index of the last appended indexing map.
//
// NOTE(review): the first parameter line of this declarator (presumably
// 'PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,') was lost
// in extraction -- confirm against the original source.
static void setupLinalgGenericOpInputAndIndexingMap(
    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
    bool isShift = false) {

  auto loc = op.getLoc();
  auto inputTy = cast<ShapedType>(op.getInput().getType());
  unsigned rank = inputTy.getRank();
  // Per-channel values index along the innermost dimension only.
  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};

  if (isConstant) {
    // If we are rescaling per-channel then we need to store the
    // values in a buffer.
    if (values.size() == 1) {
      IntegerAttr intAttr = isShift
                                ? rewriter.getI8IntegerAttr(values.front())
                                : rewriter.getI32IntegerAttr(values.front());
      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
    } else {
      // Shifts are stored as i8, multipliers as i32.
      auto elementType =
          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
      auto tensorType = RankedTensorType::get(
          {static_cast<int64_t>(values.size())}, elementType);
      DenseIntElementsAttr EltAttr;
      if (isShift)
        EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
      else
        EltAttr = DenseIntElementsAttr::get(tensorType, values);
      genericInputs.push_back(
          arith::ConstantOp::create(rewriter, loc, EltAttr));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  } else {
    // If we are not rescaling per-channel then we need to collapse 1xN to N
    // and push broadcastMap.
    auto operand = isShift ? op.getShift() : op.getMultiplier();
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    if (tensorType && tensorType.hasStaticShape() &&
        tensorType.getShape()[0] == 1) {
      // broadcastMap = affine_map<(d0, d1) -> ()>
      // It would affect as broadcast for scalar values in linalg::GenericOp.
      AffineMap broadcastMap =
          AffineMap::get(rank, 0, {}, rewriter.getContext());
      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
      indexingMaps.push_back(broadcastMap);
    } else {
      genericInputs.push_back(operand);
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  }
  arg = indexingMaps.size() - 1;
}
1485
1486// Return the extended Zp to be used in subsequent arithmetic operations.
1487static Value getExtendZp(OpBuilder &builder, Type valueTy,
1488 FailureOr<int64_t> maybeZp, Location loc,
1489 ValueRange blockArgs, int64_t zpArg,
1490 bool isOutputZp = false) {
1491 Value result;
1492 const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1493 const uint32_t attrBitwidth =
1494 isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1495 auto extendType = builder.getIntegerType(attrBitwidth);
1496 // The Zp value can be either constant or non-constant, depending on
1497 // whether dynamic extension is enabled.
1498 // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1499 // be passed as an input to linalg::GenericOp.
1500 if (failed(maybeZp)) {
1501 result = blockArgs[zpArg];
1502 auto zpTy = result.getType();
1503 if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1504 // For ExtUIOp, the input must be signless.
1505 // UnrealizedConversionCastOp will cast the input to signless type.
1506 if (zpTy.isUnsignedInteger()) {
1507 result =
1508 UnrealizedConversionCastOp::create(
1509 builder, loc,
1510 builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
1511 .getResult(0);
1512 }
1513 if (zpTy.isUnsignedInteger()) {
1514 return arith::ExtUIOp::create(builder, loc, extendType, result);
1515 } else {
1516 return arith::ExtSIOp::create(builder, loc, extendType, result);
1517 }
1518 }
1519 } else {
1520 return arith::ConstantOp::create(builder, loc,
1521 IntegerAttr::get(extendType, *maybeZp));
1522 }
1523 return result;
1524}
1525
1526class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1527public:
1528 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1529
1530 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1531 PatternRewriter &rewriter) const final {
1532 auto loc = op.getLoc();
1533 auto input = op.getInput();
1534 auto inputTy = cast<ShapedType>(op.getInput().getType());
1535 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1536 unsigned rank = inputTy.getRank();
1537
1538 // This is an illegal configuration. terminate and log an error
1539 if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
1540 return rewriter.notifyMatchFailure(
1541 op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1542 "currently supported");
1543 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
1544 return rewriter.notifyMatchFailure(
1545 op, "tosa.rescale requires scale32 for double_round to be true");
1546
1547 if (!isa<IntegerType>(inputTy.getElementType()))
1548 return rewriter.notifyMatchFailure(op, "only support integer type");
1549
1550 SmallVector<Value> dynDims;
1551 for (int i = 0; i < outputTy.getRank(); i++) {
1552 if (outputTy.isDynamicDim(i)) {
1553 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1554 }
1555 }
1556
1557 DenseElementsAttr shiftElems;
1558 bool isShiftConstant = false;
1559 if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1560 isShiftConstant = true;
1561
1562 DenseElementsAttr multiplierElems;
1563 bool isMultiplierConstant = false;
1564 if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1565 isMultiplierConstant = true;
1566
1567 llvm::SmallVector<int32_t> shiftValues;
1568 llvm::SmallVector<int32_t> multiplierValues;
1569 bool doubleRound;
1570
1571 if (isMultiplierConstant && isShiftConstant) {
1572 // explicit cast is required here
1573 shiftValues = llvm::to_vector(llvm::map_range(
1574 shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1575 return static_cast<int32_t>(attr.getInt());
1576 }));
1577 multiplierValues = llvm::to_vector(
1578 llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1579 [](IntegerAttr attr) -> int32_t {
1580 return static_cast<int32_t>(attr.getInt());
1581 }));
1582
1583 // If we shift by more than the bitwidth, this just sets to 0.
1584 for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1585 if (shiftValues[i] > 63) {
1586 shiftValues[i] = 0;
1587 multiplierValues[i] = 0;
1588 }
1589 }
1590 // Double round only occurs if shift is greater than 31, check that this
1591 // is ever true.
1592 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1593 llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1594 } else
1595 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
1596
1597 RoundingMode roundingMode =
1598 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
1599
1600 SmallVector<AffineMap> indexingMaps = {
1601 rewriter.getMultiDimIdentityMap(rank)};
1602 SmallVector<Value, 4> genericInputs = {input};
1603
1604 // If we are rescaling per-channel then we need to store the multiplier
1605 // values in a buffer.
1606 Value multiplierConstant;
1607 int64_t multiplierArg = 0;
1608 setupLinalgGenericOpInputAndIndexingMap(
1609 rewriter, multiplierValues, genericInputs, indexingMaps,
1610 isMultiplierConstant, op, multiplierConstant, multiplierArg);
1611
1612 // If we are rescaling per-channel then we need to store the shift
1613 // values in a buffer.
1614 Value shiftConstant;
1615 int64_t shiftArg = 0;
1616 setupLinalgGenericOpInputAndIndexingMap(
1617 rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1618 shiftConstant, shiftArg, true);
1619
1620 // broadcastMap = affine_map<(d0, d1) -> ()>
1621 // It would affect as broadcast for scalar values in linalg::GenericOp.
1622 AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1623 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1624 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1625 // The inputZp and outputZp may be either constant or non-constant,
1626 // depending on whether dynamic extension is enabled.
1627 // - If the zp's are non-constant, add them as an inputs to
1628 // linalg::GenericOp by:
1629 // 1. Pushing it into 'genericInputs'.
1630 // 2. Appending a corresponding affine map to 'indexingMaps'.
1631 // - If the zp's are constant, they would be generated as arith.constant.
1632 int64_t iZpArg = 0;
1633 if (failed(maybeIZp)) {
1634 genericInputs.push_back(
1635 collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1636 indexingMaps.push_back(broadcastMap);
1637 iZpArg = indexingMaps.size() - 1;
1638 }
1639 int64_t oZpArg = 0;
1640 if (failed(maybeOZp)) {
1641 genericInputs.push_back(
1642 collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1643 indexingMaps.push_back(broadcastMap);
1644 oZpArg = indexingMaps.size() - 1;
1645 }
1646
1647 // Indexing maps for output values.
1648 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1649
1650 // Construct the indexing maps needed for linalg.generic ops.
1651 Value emptyTensor = tensor::EmptyOp::create(
1652 rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
1653 ArrayRef<Value>({dynDims}));
1654
1655 auto linalgOp = linalg::GenericOp::create(
1656 rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
1657 indexingMaps, getNParallelLoopsAttrs(rank),
1658 [&](OpBuilder &nestedBuilder, Location nestedLoc,
1659 ValueRange blockArgs) {
1660 Value value = blockArgs[0];
1661 Type valueTy = value.getType();
1662
1663 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1664 auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1665 nestedLoc, blockArgs, iZpArg);
1666
1667 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1668 auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1669 nestedLoc, blockArgs, oZpArg, true);
1670
1671 IntegerType outIntType =
1672 cast<IntegerType>(blockArgs.back().getType());
1673 unsigned outBitWidth = outIntType.getWidth();
1674 assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1675
1676 Value multiplier = multiplierConstant ? multiplierConstant
1677 : blockArgs[multiplierArg];
1678 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1679
1680 if (valueTy.isUnsignedInteger()) {
1681 value = UnrealizedConversionCastOp::create(
1682 nestedBuilder, nestedLoc,
1683 nestedBuilder.getIntegerType(
1684 valueTy.getIntOrFloatBitWidth()),
1685 value)
1686 .getResult(0);
1687 }
1688 if (valueTy.getIntOrFloatBitWidth() < 32) {
1689 if (op.getInputUnsigned()) {
1690 value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
1691 nestedBuilder.getI32Type(), value);
1692 } else {
1693 value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
1694 nestedBuilder.getI32Type(), value);
1695 }
1696 }
1697
1698 value =
1699 arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
1700
1701 value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
1702 nestedBuilder.getI32Type(), value,
1703 multiplier, shift, roundingMode);
1704
1705 // Move to the new zero-point.
1706 value =
1707 arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
1708
1709 // Saturate to the output size.
1710 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1711 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1712
1713 // Unsigned integers have a difference output value.
1714 if (op.getOutputUnsigned()) {
1715 intMin = 0;
1716 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1717 }
1718
1719 auto intMinVal = arith::ConstantOp::create(
1720 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
1721 auto intMaxVal = arith::ConstantOp::create(
1722 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
1723
1724 value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1725 nestedBuilder, /*isUnsigned=*/false);
1726
1727 if (outIntType.getWidth() < 32) {
1728 value = arith::TruncIOp::create(
1729 nestedBuilder, nestedLoc,
1730 rewriter.getIntegerType(outIntType.getWidth()), value);
1731 }
1732
1733 if (outIntType.isUnsignedInteger()) {
1734 value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
1735 outIntType, value)
1736 .getResult(0);
1737 }
1738 linalg::YieldOp::create(nestedBuilder, loc, value);
1739 });
1740
1741 rewriter.replaceOp(op, linalgOp->getResults());
1742 return success();
1743 }
1744};
1745
// Handle the resize case where the input is a 1x1 image. This case
// can entirely avoid the tensor.extract operations, which are much
// more difficult to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  // Rewrites a tosa.resize whose input AND output spatial dims are all 1
  // (NHWC layout) as: collapse HxW away -> elementwise cast/scale via
  // linalg.generic -> expand back to the result shape.
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Only the degenerate 1x1 -> 1x1 case is handled by this pattern.
    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Same type in and out: the resize is an identity.
    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    // The scale operand must be a compile-time constant shape value.
    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
      return failure();
    }

    // Collapse the unit width and height away.
    // reassociation: (N, H, W, C) -> (N, [H, W, C])
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation that casts and scales the input
    // value into the result element type.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            // Bilinear accumulation scales the value by the y/x scale
            // numerators (scale = [y_n, y_d, x_n, x_d]).
            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    // Restore the unit H/W dims to match the declared result type.
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
1846
1847// TOSA resize with width or height of 1 may be broadcasted to a wider
1848// dimension. This is done by materializing a new tosa.resize without
1849// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  // Splits a broadcasting tosa.resize (unit input H and/or W expanding to a
  // wider output) into a non-broadcasting tosa.resize followed by a
  // collapse + linalg.generic broadcast to the full result shape.
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Match only if at least one of H/W is a 1 -> (>1) expansion.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    // Re-emit the resize with the broadcast dims kept at size 1.
    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse away any unit result dims. Each non-unit dim gets its own
    // reassociation group; unit dims are folded into the preceding group.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    // Shape of the collapsed tensor: batch, then any non-broadcast H/W,
    // then channels.
    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(
        builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // Input map skips the broadcast dims so linalg.generic replicates the
    // collapsed value across them.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
1946
1947class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1948public:
1949 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1950
1951 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1952 PatternRewriter &rewriter) const final {
1953 Location loc = op.getLoc();
1954 ImplicitLocOpBuilder b(loc, rewriter);
1955 auto input = op.getInput();
1956 auto inputTy = cast<ShapedType>(input.getType());
1957 auto resultTy = cast<ShapedType>(op.getType());
1958 auto resultETy = resultTy.getElementType();
1959
1960 bool floatingPointMode = isa<FloatType>(resultETy);
1961 auto floatTy = resultETy;
1962
1963 auto imageH = inputTy.getShape()[1];
1964 auto imageW = inputTy.getShape()[2];
1965
1966 auto dynamicDimsOr =
1967 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1968 if (!dynamicDimsOr.has_value())
1969 return rewriter.notifyMatchFailure(
1970 op, "unable to get dynamic dimensions of tosa.resize");
1971
1972 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1973 op.getMode() != ResizeMode::BILINEAR)
1974 return rewriter.notifyMatchFailure(
1975 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1976
1977 SmallVector<AffineMap, 2> affineMaps = {
1978 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1979 auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
1980 resultETy, *dynamicDimsOr);
1981 auto genericOp = linalg::GenericOp::create(
1982 b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1983 getNParallelLoopsAttrs(resultTy.getRank()));
1984 Value resize = genericOp.getResult(0);
1985
1986 {
1987 OpBuilder::InsertionGuard regionGuard(b);
1988 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1989 TypeRange({resultETy}), loc);
1990 Value batch = linalg::IndexOp::create(b, 0);
1991 Value y = linalg::IndexOp::create(b, 1);
1992 Value x = linalg::IndexOp::create(b, 2);
1993 Value channel = linalg::IndexOp::create(b, 3);
1994
1995 Value zeroI32 =
1996 arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
1997 Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
1998 Value hMax =
1999 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
2000 Value wMax =
2001 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
2002
2003 Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
2004 Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
2005
2006 SmallVector<int64_t> scale, offset, border;
2007 if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
2008 !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
2009 !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
2010 return rewriter.notifyMatchFailure(
2011 op, "tosa.resize scale/offset/border should have compile time "
2012 "constant values.");
2013 }
2014
2015 Value yScaleN, yScaleD, xScaleN, xScaleD;
2016 yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
2017 yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
2018 xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
2019 xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
2020
2021 Value yOffset, xOffset, yBorder, xBorder;
2022 yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
2023 xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
2024 yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
2025 xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
2026
2027 // Compute the ix and dx values for both the X and Y dimensions.
2028 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
2029 Value scaleN, Value scaleD, Value offset,
2030 int size, ImplicitLocOpBuilder &b) {
2031 if (size == 1) {
2032 index = zeroI32;
2033 delta = zeroFp;
2034 return;
2035 }
2036 // x = x * scale_d + offset;
2037 // ix = floor(x / scale_n)
2038 Value val = arith::MulIOp::create(b, in, scaleD);
2039 val = arith::AddIOp::create(b, val, offset);
2040 index = arith::FloorDivSIOp::create(b, val, scaleN);
2041
2042 // rx = x % scale_n
2043 // dx = rx / scale_n
2044 Value r = arith::RemSIOp::create(b, val, scaleN);
2045 Value rFp = arith::SIToFPOp::create(b, floatTy, r);
2046 Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
2047 delta = arith::DivFOp::create(b, rFp, scaleNfp);
2048 };
2049
2050 // Compute the ix and dx values for the X and Y dimensions - int case.
2051 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
2052 Value scaleN, Value scaleD, Value offset,
2053 int size, ImplicitLocOpBuilder &b) {
2054 if (size == 1) {
2055 index = zeroI32;
2056 delta = zeroI32;
2057 return;
2058 }
2059 // x = x * scale_d + offset;
2060 // ix = floor(x / scale_n)
2061 // dx = x - ix * scale_n;
2062 Value val = arith::MulIOp::create(b, in, scaleD);
2063 val = arith::AddIOp::create(b, val, offset);
2064 index = arith::DivSIOp::create(b, val, scaleN);
2065 delta = arith::MulIOp::create(b, index, scaleN);
2066 delta = arith::SubIOp::create(b, val, delta);
2067 };
2068
2069 Value ix, iy, dx, dy;
2070 if (floatingPointMode) {
2071 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2072 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2073 } else {
2074 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2075 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2076 }
2077
2078 if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
2079 auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2080
2081 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
2082 Value max, int size,
2083 ImplicitLocOpBuilder &b) -> Value {
2084 if (size == 1) {
2086 }
2087
2088 Value pred;
2089 if (floatingPointMode) {
2090 auto h =
2091 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
2092 pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
2093 } else {
2094 Value dvalDouble = arith::ShLIOp::create(b, dval, one);
2095 pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
2096 dvalDouble, scale);
2097 }
2098
2099 auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
2100 val = arith::AddIOp::create(b, val, offset);
2101 val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
2102 return arith::IndexCastOp::create(b, b.getIndexType(), val);
2103 };
2104
2105 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
2106 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
2107
2108 Value result = tensor::ExtractOp::create(
2109 b, input, ValueRange{batch, iy, ix, channel});
2110
2111 linalg::YieldOp::create(b, result);
2112 } else {
2113 // The mode here must be BILINEAR.
2114 assert(op.getMode() == ResizeMode::BILINEAR);
2115
2116 auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2117
2118 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
2119 Value max, ImplicitLocOpBuilder &b) {
2120 val0 = in;
2121 val1 = arith::AddIOp::create(b, val0, oneVal);
2122 val0 =
2123 clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
2124 val1 =
2125 clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
2126 val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
2127 val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
2128 };
2129
2130 // Linalg equivalent to the section below:
2131 // int16_t iy0 = apply_max(iy, 0);
2132 // int16_t iy1 = apply_min(iy + 1, IH - 1);
2133 // int16_t ix0 = apply_max(ix, 0);
2134 // int16_t ix1 = apply_min(ix + 1, IW - 1);
2135 Value x0, x1, y0, y1;
2136 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
2137 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
2138
2139 Value y0x0 = tensor::ExtractOp::create(
2140 b, input, ValueRange{batch, y0, x0, channel});
2141 Value y0x1 = tensor::ExtractOp::create(
2142 b, input, ValueRange{batch, y0, x1, channel});
2143 Value y1x0 = tensor::ExtractOp::create(
2144 b, input, ValueRange{batch, y1, x0, channel});
2145 Value y1x1 = tensor::ExtractOp::create(
2146 b, input, ValueRange{batch, y1, x1, channel});
2147
2148 if (floatingPointMode) {
2149 auto oneVal =
2150 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
2151 auto interpolate = [&](Value val0, Value val1, Value delta,
2152 int inputSize,
2153 ImplicitLocOpBuilder &b) -> Value {
2154 if (inputSize == 1)
2155 return val0;
2156 Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
2157 Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
2158 Value mul1 = arith::MulFOp::create(b, val1, delta);
2159 return arith::AddFOp::create(b, mul0, mul1);
2160 };
2161
2162 // Linalg equivalent to the section below:
2163 // topAcc = v00 * (unit_x - dx);
2164 // topAcc += v01 * dx;
2165 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
2166
2167 // Linalg equivalent to the section below:
2168 // bottomAcc = v10 * (unit_x - dx);
2169 // bottomAcc += v11 * dx;
2170 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
2171
2172 // Linalg equivalent to the section below:
2173 // result = topAcc * (unit_y - dy) + bottomAcc * dy
2174 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
2175 linalg::YieldOp::create(b, result);
2176 } else {
2177 // Perform in quantized space.
2178 y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
2179 y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
2180 y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
2181 y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
2182
2183 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
2184 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
2185 dx = arith::ExtSIOp::create(b, resultETy, dx);
2186 dy = arith::ExtSIOp::create(b, resultETy, dy);
2187 }
2188
2189 Value yScaleNExt = yScaleN;
2190 Value xScaleNExt = xScaleN;
2191
2192 const int64_t scaleBitwidth =
2193 xScaleN.getType().getIntOrFloatBitWidth();
2194 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2195 yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
2196 xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
2197 }
2198
2199 auto interpolate = [](Value val0, Value val1, Value weight1,
2200 Value scale, int inputSize,
2201 ImplicitLocOpBuilder &b) -> Value {
2202 if (inputSize == 1)
2203 return arith::MulIOp::create(b, val0, scale);
2204 Value weight0 = arith::SubIOp::create(b, scale, weight1);
2205 Value mul0 = arith::MulIOp::create(b, val0, weight0);
2206 Value mul1 = arith::MulIOp::create(b, val1, weight1);
2207 return arith::AddIOp::create(b, mul0, mul1);
2208 };
2209
2210 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2211 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2212 Value result =
2213 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2214 linalg::YieldOp::create(b, result);
2215 }
2216 }
2217 }
2218
2219 rewriter.replaceOp(op, resize);
2220 return success();
2221 }
2222};
2223
2224// At the codegen level any identity operations should be removed. Any cases
2225// where identity is load-bearing (e.g. cross device computation) should be
2226// handled before lowering to codegen.
2227template <typename SrcOp>
2228class IdentityNConverter : public OpRewritePattern<SrcOp> {
2229public:
2230 using OpRewritePattern<SrcOp>::OpRewritePattern;
2231
2232 LogicalResult matchAndRewrite(SrcOp op,
2233 PatternRewriter &rewriter) const final {
2234 rewriter.replaceOp(op, op.getOperation()->getOperands());
2235 return success();
2236 }
2237};
2238
2239template <typename SrcOp>
2240class ReduceConverter : public OpRewritePattern<SrcOp> {
2241public:
2242 using OpRewritePattern<SrcOp>::OpRewritePattern;
2243
2244 LogicalResult matchAndRewrite(SrcOp reduceOp,
2245 PatternRewriter &rewriter) const final {
2246 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2247 }
2248};
2249
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  // Lowers tosa.reverse to a linalg.generic that, for each output element,
  // extracts input[..., size - 1 - i, ...] along the reversed axis.
  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    // Collect dynamic dim sizes for the tensor.empty below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // Runtime extent of the reversed axis (works for static and dynamic).
    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = tensor::EmptyOp::create(
                           rewriter, loc, inputTy.getShape(),
                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            // Mirror the index along the reversed axis:
            // index = (size - 1) - index.
            if (i == axis) {
              auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }

            indices.push_back(index);
          }

          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
2306
// This converter translates a tile operation to a reshape, broadcast,
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim. This dim is then broadcasted to the
// appropriate multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  // Lowers tosa.tile as: linalg.generic broadcasting each input dim into a
  // (multiple, dim) pair, followed by a tosa.reshape to the result shape.
  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Multiples must be compile-time constants (-1 marks a dynamic one).
    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    // genericShape interleaves (multiple_i, dim_i) for each input dim.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    // Dynamic sizes feeding the tensor.empty below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions
    // (the odd positions of genericShape).
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    // Identity body: the affine maps alone perform the broadcast.
    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    // Fold the interleaved (multiple, dim) pairs into the final result shape.
    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
2375
2376// Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic
2377// op, producing two output buffers.
2378//
2379// The first output buffer contains the index of the found maximum value. It is
2380// initialized to 0 and is resulting integer type.
2381//
2382// The second output buffer contains the maximum value found. It is initialized
2383// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
2385// requested.
2386//
2387// The indexed_generic op updates both the maximum value and index if the
2388// current value exceeds the running max.
2389class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2390public:
2391 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2392
2393 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2394 PatternRewriter &rewriter) const final {
2395 auto loc = argmaxOp.getLoc();
2396 Value input = argmaxOp.getInput();
2397 auto inputTy = cast<ShapedType>(input.getType());
2398 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2399 auto inElementTy = inputTy.getElementType();
2400 auto outElementTy = resultTy.getElementType();
2401 int axis = argmaxOp.getAxis();
2402 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2403
2404 if (!isa<IntegerType>(outElementTy))
2405 return rewriter.notifyMatchFailure(
2406 argmaxOp,
2407 "tosa.arg_max to linalg.* requires integer-like result type");
2408
2409 SmallVector<Value> dynDims;
2410 for (int i = 0; i < inputTy.getRank(); i++) {
2411 if (inputTy.isDynamicDim(i) && i != axis) {
2412 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2413 }
2414 }
2415
2416 // First fill the output buffer for the index.
2417 auto emptyTensorIdx =
2418 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2419 outElementTy, dynDims)
2420 .getResult();
2421 auto fillValueIdx = arith::ConstantOp::create(
2422 rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2423 auto filledTensorIdx =
2424 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
2425 ValueRange{emptyTensorIdx})
2426 .result();
2427
2428 // Second fill the output buffer for the running max.
2429 auto emptyTensorMax =
2430 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2431 dynDims)
2432 .getResult();
2433 auto fillValueMaxAttr =
2434 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2435
2436 if (!fillValueMaxAttr)
2437 return rewriter.notifyMatchFailure(
2438 argmaxOp, "unsupported tosa.argmax element type");
2439
2440 auto fillValueMax =
2441 arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2442 auto filledTensorMax =
2443 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
2444 ValueRange{emptyTensorMax})
2445 .result();
2446
2447 // We need to reduce along the arg-max axis, with parallel operations along
2448 // the rest.
2449 SmallVector<utils::IteratorType, 4> iteratorTypes;
2450 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2451 iteratorTypes[axis] = utils::IteratorType::reduction;
2452
2453 SmallVector<AffineExpr, 2> srcExprs;
2454 SmallVector<AffineExpr, 2> dstExprs;
2455 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2456 srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2457 if (axis != i)
2458 dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2459 }
2460
2461 bool didEncounterError = false;
2462 auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2463 rewriter.getContext());
2464 auto linalgOp = linalg::GenericOp::create(
2465 rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2466 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2467 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2468 ValueRange blockArgs) {
2469 auto newValue = blockArgs[0];
2470 auto oldIndex = blockArgs[1];
2471 auto oldValue = blockArgs[2];
2472
2473 Value newIndex = arith::IndexCastOp::create(
2474 rewriter, nestedLoc, oldIndex.getType(),
2475 linalg::IndexOp::create(rewriter, loc, axis));
2476
2477 Value predicate;
2478 if (isa<FloatType>(inElementTy)) {
2479 if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
2480 // Only update index & max value for non NaN values. If all
2481 // values are NaNs, the initial index will be return which is 0.
2482 predicate = arith::CmpFOp::create(rewriter, nestedLoc,
2483 arith::CmpFPredicate::OGT,
2484 newValue, oldValue);
2485 } else {
2486 // Update max value if either of the following is true:
2487 // - new value is bigger
2488 // - cur max is not NaN and new value is NaN
2489 Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
2490 arith::CmpFPredicate::UGT,
2491 newValue, oldValue);
2492 Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
2493 arith::CmpFPredicate::ORD,
2494 oldValue, oldValue);
2495 predicate = arith::AndIOp::create(
2496 rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2497 }
2498 } else if (isa<IntegerType>(inElementTy)) {
2499 predicate = arith::CmpIOp::create(rewriter, nestedLoc,
2500 arith::CmpIPredicate::sgt,
2501 newValue, oldValue);
2502 } else {
2503 didEncounterError = true;
2504 return;
2505 }
2506
2507 auto resultMax = arith::SelectOp::create(
2508 rewriter, nestedLoc, predicate, newValue, oldValue);
2509 auto resultIndex = arith::SelectOp::create(
2510 rewriter, nestedLoc, predicate, newIndex, oldIndex);
2511 linalg::YieldOp::create(nestedBuilder, nestedLoc,
2512 ValueRange({resultIndex, resultMax}));
2513 });
2514
2515 if (didEncounterError)
2516 return rewriter.notifyMatchFailure(
2517 argmaxOp, "unsupported tosa.argmax element type");
2518
2519 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2520 return success();
2521 }
2522};
2523
2524class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2525public:
2526 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2527 LogicalResult
2528 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2529 ConversionPatternRewriter &rewriter) const final {
2530 auto input = adaptor.getOperands()[0];
2531 auto indices = adaptor.getOperands()[1];
2532
2533 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2534 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2535 if (!valuesTy || !resultTy)
2536 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2537
2538 auto dynamicDims = inferDynamicDimsForGather(
2539 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2540
2541 auto resultElementTy = resultTy.getElementType();
2542
2543 auto loc = op.getLoc();
2544 auto emptyTensor =
2545 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2546 resultElementTy, dynamicDims)
2547 .getResult();
2548
2549 SmallVector<AffineMap, 2> affineMaps = {
2551 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2552 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2553 rewriter.getContext()),
2554 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2555
2556 auto genericOp = linalg::GenericOp::create(
2557 rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2558 ValueRange{emptyTensor}, affineMaps,
2559 getNParallelLoopsAttrs(resultTy.getRank()),
2560 [&](OpBuilder &b, Location loc, ValueRange args) {
2561 auto indexValue = args[0];
2562 auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
2563 Value index1 = arith::IndexCastOp::create(
2564 rewriter, loc, rewriter.getIndexType(), indexValue);
2565 auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
2566 Value extract = tensor::ExtractOp::create(
2567 rewriter, loc, input, ValueRange{index0, index1, index2});
2568 linalg::YieldOp::create(rewriter, loc, extract);
2569 });
2570 rewriter.replaceOp(op, genericOp.getResult(0));
2571 return success();
2572 }
2573
2574 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2575 Location loc,
2576 Value values,
2577 Value indices) {
2578 llvm::SmallVector<Value> results;
2579
2580 auto addDynamicDimension = [&](Value source, int64_t dim) {
2581 auto sz = tensor::getMixedSize(builder, loc, source, dim);
2582 if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2583 results.push_back(dimValue);
2584 };
2585
2586 addDynamicDimension(values, 0);
2587 addDynamicDimension(indices, 1);
2588 addDynamicDimension(values, 2);
2589 return results;
2590 }
2591};
2592
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant,
// this simplifies to a single gather operation.
2596class TableConverter : public OpRewritePattern<tosa::TableOp> {
2597public:
2598 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2599
2600 LogicalResult matchAndRewrite(tosa::TableOp op,
2601 PatternRewriter &rewriter) const final {
2602 auto loc = op.getLoc();
2603 Value input = op.getInput1();
2604 Value table = op.getTable();
2605 auto inputTy = cast<ShapedType>(input.getType());
2606 auto tableTy = cast<ShapedType>(table.getType());
2607 auto resultTy = cast<ShapedType>(op.getType());
2608
2609 auto inputElementTy = inputTy.getElementType();
2610 auto tableElementTy = tableTy.getElementType();
2611 auto resultElementTy = resultTy.getElementType();
2612
2613 SmallVector<Value> dynDims;
2614 for (int i = 0; i < resultTy.getRank(); ++i) {
2615 if (inputTy.isDynamicDim(i)) {
2616 dynDims.push_back(
2617 tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
2618 }
2619 }
2620
2621 auto emptyTensor =
2622 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2623 resultElementTy, dynDims)
2624 .getResult();
2625
2626 SmallVector<AffineMap, 2> affineMaps = {
2627 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2628 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2629
2630 auto genericOp = linalg::GenericOp::create(
2631 rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
2632 affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
2633 rewriter.replaceOp(op, genericOp.getResult(0));
2634
2635 {
2636 OpBuilder::InsertionGuard regionGuard(rewriter);
2637 Block *block = rewriter.createBlock(
2638 &genericOp.getRegion(), genericOp.getRegion().end(),
2639 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2640
2641 auto inputValue = block->getArgument(0);
2642 rewriter.setInsertionPointToStart(block);
2643 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2644 resultElementTy.isInteger(8)) {
2645 Value index = arith::IndexCastOp::create(
2646 rewriter, loc, rewriter.getIndexType(), inputValue);
2647 Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
2648 index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
2649 index, offset);
2650 Value extract =
2651 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2652 linalg::YieldOp::create(rewriter, loc, extract);
2653 return success();
2654 }
2655
2656 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2657 resultElementTy.isInteger(32)) {
2658 Value extend = arith::ExtSIOp::create(
2659 rewriter, loc, rewriter.getI32Type(), inputValue);
2660
2661 auto offset = arith::ConstantOp::create(
2662 rewriter, loc, rewriter.getI32IntegerAttr(32768));
2663 auto seven = arith::ConstantOp::create(rewriter, loc,
2664 rewriter.getI32IntegerAttr(7));
2665 auto one = arith::ConstantOp::create(rewriter, loc,
2666 rewriter.getI32IntegerAttr(1));
2667 auto b1111111 = arith::ConstantOp::create(
2668 rewriter, loc, rewriter.getI32IntegerAttr(127));
2669
2670 // Compute the index and fractional part from the input value:
2671 // value = value + 32768
2672 // index = value >> 7;
2673 // fraction = 0x01111111 & value
2674 auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
2675 Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
2676 Value fraction =
2677 arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
2678
2679 // Extract the base and next values from the table.
2680 // base = (int32_t) table[index];
2681 // next = (int32_t) table[index + 1];
2682 Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
2683
2684 index = arith::IndexCastOp::create(rewriter, loc,
2685 rewriter.getIndexType(), index);
2686 indexPlusOne = arith::IndexCastOp::create(
2687 rewriter, loc, rewriter.getIndexType(), indexPlusOne);
2688
2689 Value base =
2690 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2691 Value next = tensor::ExtractOp::create(rewriter, loc, table,
2692 ValueRange{indexPlusOne});
2693
2694 base =
2695 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
2696 next =
2697 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
2698
2699 // Use the fractional part to interpolate between the input values:
2700 // result = (base << 7) + (next - base) * fraction
2701 Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
2702 Value diff = arith::SubIOp::create(rewriter, loc, next, base);
2703 Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
2704 Value result =
2705 arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
2706
2707 linalg::YieldOp::create(rewriter, loc, result);
2708
2709 return success();
2710 }
2711 }
2712
2713 return rewriter.notifyMatchFailure(
2714 op, "unable to create body for tosa.table op");
2715 }
2716};
2717
2718struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2719 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2720
2721 static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2722
2723 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2724 OpFoldResult ofr) {
2725 auto one = arith::ConstantIndexOp::create(builder, loc, 1);
2726 auto two = arith::ConstantIndexOp::create(builder, loc, 2);
2727
2728 auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2729 auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2730 auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2731 return getAsOpFoldResult(plusOne);
2732 }
2733
2734 static RankedTensorType
2735 computeOutputShape(OpBuilder &builder, Location loc, Value input,
2736 llvm::SmallVectorImpl<Value> &dynamicSizes) {
2737 // Get [N, H, W]
2738 auto dims = tensor::getMixedSizes(builder, loc, input);
2739
2740 // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2741 // output tensors.
2742 dims[2] = halfPlusOne(builder, loc, dims[2]);
2743
2744 llvm::SmallVector<int64_t, 3> staticSizes;
2745 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2746
2747 auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2748 return RankedTensorType::get(staticSizes, elementType);
2749 }
2750
2751 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2752 RankedTensorType type,
2753 llvm::ArrayRef<Value> dynamicSizes) {
2754 auto emptyTensor =
2755 tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2756 auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2757 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2758 auto filledTensor =
2759 linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
2760 ValueRange{emptyTensor})
2761 .result();
2762 return filledTensor;
2763 }
2764
2765 static Value castIndexToFloat(OpBuilder &builder, Location loc,
2766 FloatType type, Value value) {
2767 auto integerVal = arith::IndexCastUIOp::create(
2768 builder, loc,
2769 type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2770 : builder.getI32Type(),
2771 value);
2772
2773 return arith::UIToFPOp::create(builder, loc, type, integerVal);
2774 }
2775
2776 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2777 FloatType type, int64_t index) {
2778 auto indexVal = linalg::IndexOp::create(builder, loc, index);
2779 return castIndexToFloat(builder, loc, type, indexVal);
2780 }
2781
2782 template <typename... Args>
2783 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2784 Args... args) {
2785 return {builder.getAffineDimExpr(args)...};
2786 }
2787
2788 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2789 PatternRewriter &rewriter) const override {
2790 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2791 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2792 return rewriter.notifyMatchFailure(rfft2d,
2793 "only supports ranked tensors");
2794 }
2795
2796 auto loc = rfft2d.getLoc();
2797 auto input = rfft2d.getInputReal();
2798 auto elementType =
2799 dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2800 if (!elementType)
2801 return rewriter.notifyMatchFailure(rfft2d,
2802 "only supports float element types");
2803
2804 // Compute the output type and set of dynamic sizes
2805 llvm::SmallVector<Value> dynamicSizes;
2806 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2807
2808 // Iterator types for the linalg.generic implementation
2809 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2810 utils::IteratorType::parallel, utils::IteratorType::parallel,
2811 utils::IteratorType::parallel, utils::IteratorType::reduction,
2812 utils::IteratorType::reduction};
2813
2814 // Inputs/outputs to the linalg.generic implementation
2815 llvm::SmallVector<Value> genericOpInputs = {input};
2816 llvm::SmallVector<Value> genericOpOutputs = {
2817 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2818 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2819
2820 // Indexing maps for input and output tensors
2821 auto indexingMaps = AffineMap::inferFromExprList(
2822 llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2823 affineDimsExpr(rewriter, 0, 1, 2),
2824 affineDimsExpr(rewriter, 0, 1, 2)},
2825 rewriter.getContext());
2826
2827 // Width and height dimensions of the original input.
2828 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2829 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2830
2831 // Constants and dimension sizes
2832 auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2833 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2834 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2835 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2836
2837 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2838 Value valReal = args[0];
2839 Value sumReal = args[1];
2840 Value sumImag = args[2];
2841
2842 // Indices for angle computation
2843 Value oy = linalg::IndexOp::create(builder, loc, 1);
2844 Value ox = linalg::IndexOp::create(builder, loc, 2);
2845 Value iy = linalg::IndexOp::create(builder, loc, 3);
2846 Value ix = linalg::IndexOp::create(builder, loc, 4);
2847
2848 // Calculating angle without integer parts of components as sin/cos are
2849 // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2850 // / W);
2851 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2852 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2853
2854 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2855 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2856
2857 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2858 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2859
2860 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2861 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2862 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2863 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2864
2865 // realComponent = valReal * cos(angle)
2866 // imagComponent = valReal * sin(angle)
2867 auto cosAngle = math::CosOp::create(builder, loc, angle);
2868 auto sinAngle = math::SinOp::create(builder, loc, angle);
2869 auto realComponent =
2870 arith::MulFOp::create(builder, loc, valReal, cosAngle);
2871 auto imagComponent =
2872 arith::MulFOp::create(builder, loc, valReal, sinAngle);
2873
2874 // outReal = sumReal + realComponent
2875 // outImag = sumImag - imagComponent
2876 auto outReal =
2877 arith::AddFOp::create(builder, loc, sumReal, realComponent);
2878 auto outImag =
2879 arith::SubFOp::create(builder, loc, sumImag, imagComponent);
2880
2881 linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
2882 };
2883
2884 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2885 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2886 indexingMaps, iteratorTypes, buildBody);
2887
2888 return success();
2889 }
2890};
2891
2892struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2894
2895 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2896 PatternRewriter &rewriter) const override {
2897 if (!llvm::all_of(fft2d->getOperandTypes(),
2898 RFFT2dConverter::isRankedTensor) ||
2899 !llvm::all_of(fft2d->getResultTypes(),
2900 RFFT2dConverter::isRankedTensor)) {
2901 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2902 }
2903
2904 Location loc = fft2d.getLoc();
2905 Value input_real = fft2d.getInputReal();
2906 Value input_imag = fft2d.getInputImag();
2907 BoolAttr inverse = fft2d.getInverseAttr();
2908
2909 auto real_el_ty = cast<FloatType>(
2910 cast<ShapedType>(input_real.getType()).getElementType());
2911 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2912 cast<ShapedType>(input_imag.getType()).getElementType());
2913
2914 assert(real_el_ty == imag_el_ty);
2915
2916 // Compute the output type and set of dynamic sizes
2917 SmallVector<Value> dynamicSizes;
2918
2919 // Get [N, H, W]
2920 auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2921
2922 SmallVector<int64_t, 3> staticSizes;
2923 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2924
2925 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2926
2927 // Iterator types for the linalg.generic implementation
2928 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2929 utils::IteratorType::parallel, utils::IteratorType::parallel,
2930 utils::IteratorType::parallel, utils::IteratorType::reduction,
2931 utils::IteratorType::reduction};
2932
2933 // Inputs/outputs to the linalg.generic implementation
2934 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2935 SmallVector<Value> genericOpOutputs = {
2936 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2937 dynamicSizes),
2938 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2939 dynamicSizes)};
2940
2941 // Indexing maps for input and output tensors
2942 auto indexingMaps = AffineMap::inferFromExprList(
2943 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2944 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2945 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2946 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2947 rewriter.getContext());
2948
2949 // Width and height dimensions of the original input.
2950 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2951 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2952
2953 // Constants and dimension sizes
2954 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2955 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2956 Value constH =
2957 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2958 Value constW =
2959 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2960
2961 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2962 Value valReal = args[0];
2963 Value valImag = args[1];
2964 Value sumReal = args[2];
2965 Value sumImag = args[3];
2966
2967 // Indices for angle computation
2968 Value oy = linalg::IndexOp::create(builder, loc, 1);
2969 Value ox = linalg::IndexOp::create(builder, loc, 2);
2970 Value iy = linalg::IndexOp::create(builder, loc, 3);
2971 Value ix = linalg::IndexOp::create(builder, loc, 4);
2972
2973 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2974 // ox) % W ) / W);
2975 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2976 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2977
2978 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2979 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2980
2981 auto iyRemFloat =
2982 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2983 auto ixRemFloat =
2984 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2985
2986 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2987 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2988
2989 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2990 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2991
2992 if (inverse.getValue()) {
2993 angle = arith::MulFOp::create(
2994 builder, loc, angle,
2995 arith::ConstantOp::create(rewriter, loc,
2996 rewriter.getFloatAttr(real_el_ty, -1.0)));
2997 }
2998
2999 // realComponent = val_real * cos(a) + val_imag * sin(a);
3000 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
3001 auto cosAngle = math::CosOp::create(builder, loc, angle);
3002 auto sinAngle = math::SinOp::create(builder, loc, angle);
3003
3004 auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
3005 auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
3006 auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
3007
3008 auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
3009 auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
3010
3011 auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
3012
3013 // outReal = sumReal + realComponent
3014 // outImag = sumImag - imagComponent
3015 auto outReal =
3016 arith::AddFOp::create(builder, loc, sumReal, realComponent);
3017 auto outImag =
3018 arith::AddFOp::create(builder, loc, sumImag, imagComponent);
3019
3020 linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
3021 };
3022
3023 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
3024 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
3025 indexingMaps, iteratorTypes, buildBody);
3026
3027 return success();
3028 }
3029};
3030
3031} // namespace
3032
3034 const TypeConverter &converter, RewritePatternSet *patterns) {
3035
3036 // We have multiple resize coverters to handle degenerate cases.
3037 patterns->add<GenericResizeConverter>(patterns->getContext(),
3038 /*benefit=*/100);
3039 patterns->add<ResizeUnaryConverter>(patterns->getContext(),
3040 /*benefit=*/200);
3041 patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
3042 /*benefit=*/300);
3043
3044 patterns->add<
3045 // clang-format off
3046 PointwiseConverter<tosa::AddOp>,
3047 PointwiseConverter<tosa::SubOp>,
3048 PointwiseConverter<tosa::MulOp>,
3049 PointwiseConverter<tosa::IntDivOp>,
3050 PointwiseConverter<tosa::NegateOp>,
3051 PointwiseConverter<tosa::PowOp>,
3052 PointwiseConverter<tosa::ReciprocalOp>,
3053 PointwiseConverter<tosa::RsqrtOp>,
3054 PointwiseConverter<tosa::LogOp>,
3055 PointwiseConverter<tosa::ExpOp>,
3056 PointwiseConverter<tosa::AbsOp>,
3057 PointwiseConverter<tosa::SinOp>,
3058 PointwiseConverter<tosa::CosOp>,
3059 PointwiseConverter<tosa::TanhOp>,
3060 PointwiseConverter<tosa::ErfOp>,
3061 PointwiseConverter<tosa::BitwiseAndOp>,
3062 PointwiseConverter<tosa::BitwiseOrOp>,
3063 PointwiseConverter<tosa::BitwiseNotOp>,
3064 PointwiseConverter<tosa::BitwiseXorOp>,
3065 PointwiseConverter<tosa::LogicalAndOp>,
3066 PointwiseConverter<tosa::LogicalNotOp>,
3067 PointwiseConverter<tosa::LogicalOrOp>,
3068 PointwiseConverter<tosa::LogicalXorOp>,
3069 PointwiseConverter<tosa::CastOp>,
3070 PointwiseConverter<tosa::LogicalLeftShiftOp>,
3071 PointwiseConverter<tosa::LogicalRightShiftOp>,
3072 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
3073 PointwiseConverter<tosa::ClzOp>,
3074 PointwiseConverter<tosa::SelectOp>,
3075 PointwiseConverter<tosa::GreaterOp>,
3076 PointwiseConverter<tosa::GreaterEqualOp>,
3077 PointwiseConverter<tosa::EqualOp>,
3078 PointwiseConverter<tosa::MaximumOp>,
3079 PointwiseConverter<tosa::MinimumOp>,
3080 PointwiseConverter<tosa::CeilOp>,
3081 PointwiseConverter<tosa::FloorOp>,
3082 PointwiseConverter<tosa::ClampOp>,
3083 PointwiseConverter<tosa::SigmoidOp>
3084 >(converter, patterns->getContext());
3085
3086 patterns->add<
3087 IdentityNConverter<tosa::IdentityOp>,
3088 ReduceConverter<tosa::ReduceAllOp>,
3089 ReduceConverter<tosa::ReduceAnyOp>,
3090 ReduceConverter<tosa::ReduceMinOp>,
3091 ReduceConverter<tosa::ReduceMaxOp>,
3092 ReduceConverter<tosa::ReduceSumOp>,
3093 ReduceConverter<tosa::ReduceProductOp>,
3094 ArgMaxConverter,
3095 GatherConverter,
3096 RescaleConverter,
3097 ReverseConverter,
3098 RFFT2dConverter,
3099 FFT2dConverter,
3100 TableConverter,
3101 TileConverter>(patterns->getContext());
3102 // clang-format on
3103}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
DenseMap< int64_t, Value > IndexPool
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static bool operandsAndResultsRanked(Operation *operation)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
BlockArgument getArgument(unsigned i)
Definition Block.h:129
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
FloatType getF32Type()
Definition Builders.cpp:43
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:387
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:372
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getI32Type()
Definition Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
MLIRContext * getContext() const
Definition Builders.h:56
IntegerAttr getI8IntegerAttr(int8_t value)
Definition Builders.cpp:221
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:57
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...