MLIR 23.0.0git
TosaToLinalg.cpp
Go to the documentation of this file.
//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
12
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/Sequence.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32
33#include <type_traits>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38// Helper function to materialize the semantically correct compare and select
39// operations given a binary operation with a specific NaN propagation mode.
40//
41// In the case of "PROPAGATE" semantics no compare and selection is required and
42// this function does nothing.
43//
44// In the case of "IGNORE" semantics this function materializes a comparison of
45// the current operands to the op which will return true for any NaN
46// argument and then selects between the non-NaN operation argument and the
47// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
48// code:
49//
50// In the case that the op is operating on non floating point types we ignore
51// the attribute completely, this is consistent with the TOSA spec which has
52// the following wording: "This attribute is ignored by non floating-point
53// types."
54//
55// binary<op>(lhs, rhs):
56// result = op(lhs, rhs)
57// if lhs == NaN return rhs
58// if rhs == NaN return lhs
59// return result
60template <typename OpTy>
61static Value
64 // NaN propagation has no meaning for non floating point types.
65 if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
66 return result;
67
68 auto nanMode = op.getNanMode();
69 if (nanMode == NanPropagationMode::PROPAGATE)
70 return result;
71
72 // Unordered comparison of NaN against itself will always return true.
73 Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
74 arith::CmpFPredicate::UNO, lhs, lhs);
75 Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
76 arith::CmpFPredicate::UNO, rhs, rhs);
77 Value rhsOrResult =
78 arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
79 return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
80 rhsOrResult);
81}
82
84 Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
85 ConversionPatternRewriter &rewriter) {
86 Location loc = op->getLoc();
87 auto elementTy =
88 cast<ShapedType>(op->getOperand(0).getType()).getElementType();
89
90 // tosa::AbsOp
91 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
92 return math::AbsFOp::create(rewriter, loc, resultTypes, args);
93
94 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
95 auto zero = arith::ConstantOp::create(rewriter, loc,
96 rewriter.getZeroAttr(elementTy));
97 auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
98 return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
99 }
100
101 // tosa::AddOp
102 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
103 return arith::AddFOp::create(rewriter, loc, resultTypes, args);
104
105 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
106 return arith::AddIOp::create(rewriter, loc, resultTypes, args);
107
108 // tosa::SubOp
109 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
110 return arith::SubFOp::create(rewriter, loc, resultTypes, args);
111
112 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
113 return arith::SubIOp::create(rewriter, loc, resultTypes, args);
114
115 // tosa::IntDivOp
116 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
117 return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
118
119 // tosa::ReciprocalOp
120 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
121 auto one =
122 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
123 return arith::DivFOp::create(rewriter, loc, one, args[0]);
124 }
125
126 // tosa::MulOp
127 if (isa<tosa::MulOp>(op)) {
128 auto shiftVal = cast<tosa::MulOp>(op).getShift();
129 DenseElementsAttr shiftElem;
130 bool shiftIsConstant = true;
131 int32_t shift = 0;
132 if (matchPattern(shiftVal, m_Constant(&shiftElem)))
133 shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
134 else
135 shiftIsConstant = false;
136
137 if (isa<FloatType>(elementTy)) {
138 if (shift != 0) {
139 (void)rewriter.notifyMatchFailure(op,
140 "Cannot have shift value for float");
141 return nullptr;
142 }
143 return arith::MulFOp::create(rewriter, loc, args[0], args[1]);
144 }
145
146 if (isa<IntegerType>(elementTy)) {
147 Value a = args[0];
148 Value b = args[1];
149
150 if (shift > 0 || !shiftIsConstant) {
151 Value shiftConst;
152 if (shiftIsConstant)
153 shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
154 /*bitwidth=*/8);
155
156 if (!a.getType().isInteger(32))
157 a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
158
159 if (!b.getType().isInteger(32))
160 b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
161
162 auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
163 auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
164 RoundingMode::SINGLE_ROUND);
165 auto result =
166 tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
167 b, shiftAmount, roundingAttr);
168
169 return result;
170 }
171
172 int aWidth = a.getType().getIntOrFloatBitWidth();
173 int bWidth = b.getType().getIntOrFloatBitWidth();
174 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
175
176 if (aWidth < cWidth)
177 a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
178 if (bWidth < cWidth)
179 b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
180
181 return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
182 }
183 }
184
185 // tosa::NegateOp
186 if (isa<tosa::NegateOp>(op)) {
187 auto negate = cast<tosa::NegateOp>(op);
188
189 int64_t inZp = 0, outZp = 0;
190 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
191 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
192 bool hasInZp = !failed(maybeInZp);
193 bool hasOutZp = !failed(maybeOutZp);
194 if (hasInZp)
195 inZp = *maybeInZp;
196 if (hasOutZp)
197 outZp = *maybeOutZp;
198
199 if (isa<FloatType>(elementTy))
200 return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
201
202 if (isa<IntegerType>(elementTy)) {
203 Value zpAddValue;
204 Type intermediateType;
205 // Compute the maximum value that can occur in the intermediate buffer.
206 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
207 int intermediateBitWidth = 64;
208
209 if (hasInZp && hasOutZp) {
210 // Compute the maximum value that can occur in the intermediate buffer.
211 const int64_t zpAdd = inZp + outZp;
212 const int64_t maxValue =
213 APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
214 std::abs(zpAdd) + 1;
215
216 // Convert that maximum value into the maximum bitwidth needed to
217 // represent it.
218 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
219 intermediateBitWidth = 16;
220 } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
221 intermediateBitWidth = 32;
222 }
223
224 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
225 zpAddValue = arith::ConstantOp::create(
226 rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
227 } else {
228 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
229 Value arg1 = args[1];
230 Value arg2 = args[2];
231 // Avoid verifier-invalid no-op sign-extends; only widen when needed.
232 if (arg1.getType() != intermediateType)
233 arg1 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg1);
234 if (arg2.getType() != intermediateType)
235 arg2 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg2);
236 zpAddValue =
237 arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
238 }
239
240 // The negation can be applied by doing:
241 // outputValue = inZp + outZp - inputValue
242 Value ext = args[0];
243 if (ext.getType() != intermediateType)
244 ext = arith::ExtSIOp::create(rewriter, loc, intermediateType, ext);
245 auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
246
247 // Clamp to the negation range.
249 rewriter, loc, intermediateType,
250 APInt::getSignedMinValue(inputBitWidth).getSExtValue());
252 rewriter, loc, intermediateType,
253 APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
254 auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
255
256 // Truncate to the final value, skipping no-op trunci when widths match.
257 if (clamp.getType() == elementTy)
258 return clamp;
259 return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
260 }
261 }
262
263 // tosa::BitwiseAndOp
264 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
265 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
266
267 // tosa::BitwiseOrOp
268 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
269 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
270
271 // tosa::BitwiseNotOp
272 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
273 auto allOnesAttr = rewriter.getIntegerAttr(
274 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
275 auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
276 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
277 }
278
279 // tosa::BitwiseXOrOp
280 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
281 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
282
283 // tosa::LogicalLeftShiftOp
284 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
285 return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
286
287 // tosa::LogicalRightShiftOp
288 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
289 return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
290
291 // tosa::ArithmeticRightShiftOp
292 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
293 auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
294 auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
295 if (!round) {
296 return result;
297 }
298
299 Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
300 auto one = arith::ConstantOp::create(rewriter, loc,
301 IntegerAttr::get(elementTy, 1));
302 auto zero = arith::ConstantOp::create(rewriter, loc,
303 IntegerAttr::get(elementTy, 0));
304 auto i1zero =
305 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
306 auto i1one =
307 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
308
309 // Checking that input2 != 0
310 auto shiftValueGreaterThanZero = arith::CmpIOp::create(
311 rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
312
313 // Checking for the last bit of input1 to be 1
314 auto subtract =
315 arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
316 auto shifted =
317 arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
318 ->getResults();
319 auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
321 auto isInputOdd =
322 arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
323 // shifted, truncated, isInputOdd can be poison when input2 is 0.
324 auto shouldRound = arith::SelectOp::create(
325 rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
326 auto extended =
327 arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
328 return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
329 }
330
331 // tosa::ClzOp
332 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
333 return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
334 }
335
336 // tosa::LogicalAnd
337 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
338 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
339
340 // tosa::LogicalNot
341 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
342 auto one = arith::ConstantOp::create(rewriter, loc,
343 rewriter.getIntegerAttr(elementTy, 1));
344 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
345 }
346
347 // tosa::LogicalOr
348 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
349 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
350
351 // tosa::LogicalXor
352 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
353 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
354
355 // tosa::PowOp
356 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
357 return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
358
359 // tosa::RsqrtOp
360 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
361 return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
362
363 // tosa::LogOp
364 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
365 return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
366
367 // tosa::ExpOp
368 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
369 return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
370
371 // tosa::SinOp
372 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
373 return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
374
375 // tosa::CosOp
376 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
377 return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
378
379 // tosa::TanhOp
380 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
381 return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
382
383 // tosa::ErfOp
384 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
385 return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
386
387 // tosa::GreaterOp
388 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
389 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
390 args[0], args[1]);
391
392 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
393 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
394 args[0], args[1]);
395
396 // tosa::GreaterEqualOp
397 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
398 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
399 args[0], args[1]);
400
401 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
402 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
403 args[0], args[1]);
404
405 // tosa::EqualOp
406 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
407 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
408 args[0], args[1]);
409
410 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
411 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
412 args[0], args[1]);
413
414 // tosa::SelectOp
415 if (isa<tosa::SelectOp>(op)) {
416 elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
417 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
418 return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
419 }
420
421 // tosa::MaximumOp
422 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
423 auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
424 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
425 rewriter, args[0], args[1], max);
426 }
427
428 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
429 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
430 }
431
432 // tosa::MinimumOp
433 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
434 auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
435 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
436 rewriter, args[0], args[1], min);
437 }
438
439 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
440 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
441 }
442
443 // tosa::CeilOp
444 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
445 return math::CeilOp::create(rewriter, loc, resultTypes, args);
446
447 // tosa::FloorOp
448 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
449 return math::FloorOp::create(rewriter, loc, resultTypes, args);
450
451 // tosa::ClampOp
452 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
453 bool losesInfo = false;
454 APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
455 APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
456 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
457 APFloat::rmNearestTiesToEven, &losesInfo);
458 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
459 APFloat::rmNearestTiesToEven, &losesInfo);
460 auto min = arith::ConstantOp::create(
461 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
462 auto max = arith::ConstantOp::create(
463 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
464 auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
465
466 auto clampOp = llvm::cast<tosa::ClampOp>(op);
467 const auto nanMode = clampOp.getNanMode();
468
469 // NaN propagation has no meaning for non floating point types.
470 if (!isa<FloatType>(elementTy))
471 return result;
472
473 // In the case of "PROPAGATE" semantics no compare and selection is
474 // required.
475 if (nanMode == NanPropagationMode::PROPAGATE)
476 return result;
477
478 // In the case of "IGNORE" semantics materialize a comparison
479 // of the current operand to the reduction which will return true for a NaN
480 // argument and then selects between the initial reduction value and the
481 // calculated result based on whether the argument is NaN or not. In pseudo
482 // code:
483 //
484 // reduce<op>(x, init):
485 // result = op(init, x)
486 // return init if x == NaN else result
487
488 // Unordered comparison of NaN against itself will always return true.
489 Value isNaN = arith::CmpFOp::create(
490 rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
491 // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
492 // is NaN.
493 return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
494 }
495
496 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
497 auto intTy = cast<IntegerType>(elementTy);
498 int64_t min =
499 cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
500 int64_t max =
501 cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
502
503 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
504 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
505 if (intTy.isUnsignedInteger()) {
506 minRepresentable = 0;
507 if (intTy.getIntOrFloatBitWidth() <= 63) {
508 maxRepresentable =
509 (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
510 .getZExtValue();
511 }
512 } else if (intTy.getIntOrFloatBitWidth() <= 64) {
513 // Ensure that min & max fit into signed n-bit constants.
514 minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
515 .getSExtValue();
516 maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
517 .getSExtValue();
518 }
519 // Ensure that the bounds are representable as n-bit signed/unsigned
520 // integers.
521 min = std::max(min, minRepresentable);
522 max = std::max(max, minRepresentable);
523 min = std::min(min, maxRepresentable);
524 max = std::min(max, maxRepresentable);
525
526 auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
527 intTy.getIntOrFloatBitWidth());
528 auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
529 intTy.getIntOrFloatBitWidth());
530 return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
531 intTy.isUnsignedInteger());
532 }
533
534 // tosa::SigmoidOp
535 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
536 auto one =
537 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
538 auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
539 auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
540 auto added = arith::AddFOp::create(rewriter, loc, exp, one);
541 return arith::DivFOp::create(rewriter, loc, one, added);
542 }
543
544 // tosa::CastOp
545 if (isa<tosa::CastOp>(op)) {
546 Type srcTy = elementTy;
547 Type dstTy = resultTypes.front();
548 if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
549 (void)rewriter.notifyMatchFailure(op, "unsupported type");
550 return nullptr;
551 }
552
553 bool bitExtend =
555
556 if (srcTy == dstTy)
557 return args.front();
558
559 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
560 return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
562
563 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
564 return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
566
567 // 1-bit integers need to be treated as signless.
568 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
569 return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
571
572 if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
573 return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
575
576 // Unsigned integers need an unrealized cast so that they can be passed
577 // to UIToFP.
578 if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
579 auto unrealizedCast =
580 UnrealizedConversionCastOp::create(
581 rewriter, loc,
582 rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
583 .getResult(0);
584 return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
585 unrealizedCast);
586 }
587
588 // All other si-to-fp conversions should be handled by SIToFP.
589 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
590 return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
592
593 // Casting to boolean, floats need to only be checked as not-equal to zero.
594 if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
595 Value zero = arith::ConstantOp::create(rewriter, loc,
596 rewriter.getFloatAttr(srcTy, 0.0));
597 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
598 args.front(), zero);
599 }
600
601 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
602 auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
603
604 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
605 // Check whether neither int min nor int max can be represented in the
606 // input floating-point type due to too short exponent range.
607 if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
608 APFloat::semanticsMaxExponent(fltSemantics)) {
609 // Use cmp + select to replace infinites by int min / int max. Other
610 // integral values can be represented in the integer space.
611 auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
612 auto posInf = arith::ConstantOp::create(
613 rewriter, loc,
614 rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
615 APFloat::getInf(fltSemantics)));
616 auto negInf = arith::ConstantOp::create(
617 rewriter, loc,
618 rewriter.getFloatAttr(
620 APFloat::getInf(fltSemantics, /*Negative=*/true)));
621 auto overflow = arith::CmpFOp::create(
622 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
623 auto underflow = arith::CmpFOp::create(
624 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
625 auto intMin = arith::ConstantOp::create(
626 rewriter, loc,
627 rewriter.getIntegerAttr(
629 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
630 auto intMax = arith::ConstantOp::create(
631 rewriter, loc,
632 rewriter.getIntegerAttr(
634 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
635 auto maxClamped =
636 arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
637 return arith::SelectOp::create(rewriter, loc, underflow, intMin,
638 maxClamped);
639 }
640
641 auto intMinFP = arith::ConstantOp::create(
642 rewriter, loc,
643 rewriter.getFloatAttr(
645 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
646 .getSExtValue()));
647
648 // Check whether the mantissa has enough bits to represent int max.
649 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
650 dstTy.getIntOrFloatBitWidth() - 1) {
651 // Int min can also be represented since it is a power of two and thus
652 // consists of a single leading bit. Therefore we can clamp the input
653 // in the floating-point domain.
654
655 auto intMaxFP = arith::ConstantOp::create(
656 rewriter, loc,
657 rewriter.getFloatAttr(
659 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
660 .getSExtValue()));
661
662 Value clamped =
663 clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
664 return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
665 }
666
667 // Due to earlier check we know exponant range is big enough to represent
668 // int min. We can therefore rely on int max + 1 being representable as
669 // well because it's just int min with a positive sign. So clamp the min
670 // value and compare against that to select the max int value if needed.
671 auto intMaxPlusOneFP = arith::ConstantOp::create(
672 rewriter, loc,
673 rewriter.getFloatAttr(
675 static_cast<double>(
676 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
677 .getSExtValue()) +
678 1.0f));
679
680 auto intMax = arith::ConstantOp::create(
681 rewriter, loc,
682 rewriter.getIntegerAttr(
684 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
685 auto minClampedFP =
686 arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
687 auto minClamped =
688 arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
689 auto overflow = arith::CmpFOp::create(
690 rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
691 return arith::SelectOp::create(rewriter, loc, overflow, intMax,
692 minClamped);
693 }
694
695 // Casting to boolean, integers need to only be checked as not-equal to
696 // zero.
697 if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
698 Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
699 srcTy.getIntOrFloatBitWidth());
700 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
701 args.front(), zero);
702 }
703
704 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
705 return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
707
708 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
709 return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
710 }
711 }
712
713 (void)rewriter.notifyMatchFailure(
714 op, "unhandled op for linalg body calculation for elementwise op");
715 return nullptr;
716}
717
719
720// Emit an 'arith.constant' op for the given index if it has not been created
721// yet, or return an existing constant. This will prevent an excessive creation
722// of redundant constants, easing readability of emitted code for unit tests.
724 IndexPool &indexPool, int64_t index) {
725 auto [it, inserted] = indexPool.try_emplace(index);
726 if (inserted)
727 it->second =
728 arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
729 return it->second;
730}
731
733 IndexPool &indexPool, Value tensor, int64_t index) {
734 auto indexValue = createIndex(rewriter, loc, indexPool, index);
735 return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
736}
737
739 IndexPool &indexPool, Value tensor,
740 int64_t index) {
741 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
742 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
743 assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
744 if (shapedType.isDynamicDim(index))
745 return getTensorDim(rewriter, loc, indexPool, tensor, index);
746 return rewriter.getIndexAttr(shapedType.getDimSize(index));
747}
748
749static bool operandsAndResultsRanked(Operation *operation) {
750 auto isRanked = [](Value value) {
751 return isa<RankedTensorType>(value.getType());
752 };
753 return llvm::all_of(operation->getOperands(), isRanked) &&
754 llvm::all_of(operation->getResults(), isRanked);
755}
756
757// Compute the runtime dimension size for dimension 'dim' of the output by
758// inspecting input 'operands', all of which are expected to have the same rank.
759// This function returns a pair {targetSize, masterOperand}.
760//
761// The runtime size of the output dimension is returned either as a statically
762// computed attribute or as a runtime SSA value.
763//
764// If the target size was inferred directly from one dominating operand, that
765// operand is returned in 'masterOperand'. If the target size is inferred from
766// multiple operands, 'masterOperand' is set to nullptr.
767static std::pair<OpFoldResult, Value>
769 ValueRange operands, int64_t dim) {
770 // If any input operand contains a static size greater than 1 for this
771 // dimension, that is the target size. An occurrence of an additional static
772 // dimension greater than 1 with a different value is undefined behavior.
773 for (auto operand : operands) {
774 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
775 if (ShapedType::isStatic(size) && size > 1)
776 return {rewriter.getIndexAttr(size), operand};
777 }
778
779 // Filter operands with dynamic dimension
780 auto operandsWithDynamicDim =
781 llvm::filter_to_vector(operands, [&](Value operand) {
782 return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
783 });
784
785 // If no operand has a dynamic dimension, it means all sizes were 1
786 if (operandsWithDynamicDim.empty())
787 return {rewriter.getIndexAttr(1), operands.front()};
788
789 // Emit code that computes the runtime size for this dimension. If there is
790 // only one operand with a dynamic dimension, it is considered the master
791 // operand that determines the runtime size of the output dimension.
792 auto targetSize =
793 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
794 if (operandsWithDynamicDim.size() == 1)
795 return {targetSize, operandsWithDynamicDim[0]};
796
797 // Calculate maximum size among all dynamic dimensions
798 for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
799 auto nextSize =
800 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
801 targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
802 }
803 return {targetSize, nullptr};
804}
805
806// Compute the runtime output size for all dimensions. This function returns
807// a pair {targetShape, masterOperands}.
808static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
810 IndexPool &indexPool, ValueRange operands) {
811 assert(!operands.empty());
812 auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
813 SmallVector<OpFoldResult> targetShape;
814 SmallVector<Value> masterOperands;
815 for (auto dim : llvm::seq<int64_t>(0, rank)) {
816 auto [targetSize, masterOperand] =
817 computeTargetSize(rewriter, loc, indexPool, operands, dim);
818 targetShape.push_back(targetSize);
819 masterOperands.push_back(masterOperand);
820 }
821 return {targetShape, masterOperands};
822}
823
825 IndexPool &indexPool, Value operand,
826 int64_t dim, OpFoldResult targetSize,
827 Value masterOperand) {
// Broadcasts dimension 'dim' of 'operand' to 'targetSize' at runtime when the
// dimension's runtime size turns out to be 1; otherwise the operand is
// forwarded unchanged.
828 // Nothing to do if this is a static dimension
829 auto rankedTensorType = cast<RankedTensorType>(operand.getType());
830 if (!rankedTensorType.isDynamicDim(dim))
831 return operand;
832
833 // If the target size for this dimension was directly inferred by only taking
834 // this operand into account, there is no need to broadcast. This is an
835 // optimization that will prevent redundant control flow, and constitutes the
836 // main motivation for tracking "master operands".
837 if (operand == masterOperand)
838 return operand;
839
840 // Affine maps for 'linalg.generic' op
// The broadcast map pins dimension 'dim' to constant 0 (always reading the
// single element) while every other dimension stays an identity expression.
841 auto rank = rankedTensorType.getRank();
842 SmallVector<AffineExpr> affineExprs;
843 for (auto index : llvm::seq<int64_t>(0, rank)) {
844 auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
845 : rewriter.getAffineDimExpr(index);
846 affineExprs.push_back(affineExpr);
847 }
848 auto broadcastAffineMap =
849 AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
850 auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
851 SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
852
853 // Check if broadcast is necessary
// The 'then' region is taken only when the runtime size of this dimension
// equals 1; the 'else' region yields the operand untouched.
854 auto one = createIndex(rewriter, loc, indexPool, 1);
855 auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
856 auto broadcastNecessary = arith::CmpIOp::create(
857 rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
858
859 // Emit 'then' region of 'scf.if'
860 auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
861 // It is not safe to cache constants across regions.
862 // New constants could potentially violate dominance requirements.
863 IndexPool localPool;
864
865 // Emit 'tensor.empty' op
// NOTE(review): the sizes below are created with 'rewriter' (plus the fresh
// localPool) while the tensor ops use 'opBuilder' -- confirm both insert
// into the same region.
866 SmallVector<OpFoldResult> outputTensorShape;
867 for (auto index : llvm::seq<int64_t>(0, rank)) {
868 auto size = index == dim ? targetSize
869 : getOrFoldTensorDim(rewriter, loc, localPool,
870 operand, index);
871 outputTensorShape.push_back(size);
872 }
873 Value outputTensor = tensor::EmptyOp::create(
874 opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
875
876 // Emit 'linalg.generic' op
// Replicates the single element across the expanded dimension.
877 auto resultTensor =
878 linalg::GenericOp::create(
879 opBuilder, loc, outputTensor.getType(), operand, outputTensor,
880 affineMaps, getNParallelLoopsAttrs(rank),
881 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
882 // Emit 'linalg.yield' op
883 linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
884 })
885 .getResult(0);
886
887 // Cast to original operand type if necessary
888 auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
889 loc, operand.getType(), resultTensor);
890
891 // Emit 'scf.yield' op
892 scf::YieldOp::create(opBuilder, loc, castResultTensor);
893 };
894
895 // Emit 'else' region of 'scf.if'
896 auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
897 scf::YieldOp::create(opBuilder, loc, operand);
898 };
899
900 // Emit 'scf.if' op
901 auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
902 emitThenRegion, emitElseRegion);
903 return ifOp.getResult(0);
904}
905
907 IndexPool &indexPool, Value operand,
908 ArrayRef<OpFoldResult> targetShape,
909 ArrayRef<Value> masterOperands) {
// Broadcasts every dynamic dimension of 'operand' to its per-dimension target
// size, one dimension at a time. 'targetShape' and 'masterOperands' must
// supply exactly one entry per tensor dimension.
910 int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
911 assert((int64_t)targetShape.size() == rank);
912 assert((int64_t)masterOperands.size() == rank);
913 for (auto index : llvm::seq<int64_t>(0, rank))
914 operand =
915 broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
916 targetShape[index], masterOperands[index]);
917 return operand;
918}
919
922 IndexPool &indexPool, ValueRange operands,
923 ArrayRef<OpFoldResult> targetShape,
924 ArrayRef<Value> masterOperands) {
925 // No need to broadcast for unary operations
926 if (operands.size() == 1)
927 return operands;
928
929 // No need to broadcast for static shape
// Fully static shapes need no runtime broadcasting; static size-1 dims are
// handled by the indexing maps built in emitElementwiseComputation.
930 bool hasDynamic = false;
931 for (auto op : operands) {
932 const auto tType = dyn_cast<RankedTensorType>(op.getType());
933 if (tType && !tType.hasStaticShape()) {
934 hasDynamic = true;
935 break;
936 }
937 }
938 if (!hasDynamic)
939 return operands;
940
941 // Broadcast dynamic dimensions operand by operand
942 return llvm::map_to_vector(operands, [&](Value operand) {
943 return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
944 targetShape, masterOperands);
945 });
946}
947
// Emits the element-wise payload as a linalg.generic over the already
// broadcast operands and replaces 'operation' with the generic's (possibly
// cast) result.
948static LogicalResult
949emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
950 Operation *operation, ValueRange operands,
951 ArrayRef<OpFoldResult> targetShape,
952 const TypeConverter &converter) {
953 // Generate output tensor
954 auto resultType = cast_or_null<RankedTensorType>(
955 converter.convertType(operation->getResultTypes().front()));
956 if (!resultType) {
957 return rewriter.notifyMatchFailure(operation, "failed to convert type");
958 }
959 Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
960 resultType.getElementType());
961
962 // Create affine maps. Input affine maps broadcast static dimensions of size
963 // 1. The output affine map is an identity map.
964 //
965 auto rank = resultType.getRank();
966 auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
967 auto shape = cast<ShapedType>(operand.getType()).getShape();
968 SmallVector<AffineExpr> affineExprs;
969 for (auto it : llvm::enumerate(shape)) {
970 // Prefer producing identity maps whenever possible (i.e. no broadcasting
971 // needed) because some transforms (like reshape folding)
972 // do not support affine constant exprs.
973 bool requiresBroadcast =
974 (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
975 auto affineExpr = requiresBroadcast
976 ? rewriter.getAffineConstantExpr(0)
977 : rewriter.getAffineDimExpr(it.index());
978 affineExprs.push_back(affineExpr);
979 }
980 return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
981 });
982 affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
983
984 // Emit 'linalg.generic' op
// The body builder flags failure through 'encounteredError' because the
// callback itself cannot propagate a LogicalResult.
985 bool encounteredError = false;
986 auto linalgOp = linalg::GenericOp::create(
987 rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
989 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
991 operation, blockArgs.take_front(operation->getNumOperands()),
992 {resultType.getElementType()}, rewriter);
993 if (!opResult) {
994 encounteredError = true;
995 return;
996 }
997 linalg::YieldOp::create(opBuilder, loc, opResult);
998 });
999 if (encounteredError)
1000 return rewriter.notifyMatchFailure(
1001 operation, "unable to create linalg.generic body for elementwise op");
1002
1003 // Cast 'linalg.generic' result into original result type if needed
1004 auto castResult = rewriter.createOrFold<tensor::CastOp>(
1005 loc, resultType, linalgOp->getResult(0));
1006 rewriter.replaceOp(operation, castResult);
1007 return success();
1008}
1009
1011 ValueRange operands) {
// Returns the prefix of 'operands' that should take part in shape inference
// and broadcasting; trailing operands (a constant mul shift, constant negate
// zero-points) are excluded from the broadcast set.
1012 // Shift cannot broadcast
1013 if (isa<tosa::MulOp>(operation)) {
1014 DenseElementsAttr shiftElems;
1015 // Shift cannot broadcast when it is constant
1016 if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
1017 return operands.take_front(2);
1018 else
1019 return operands.take_front(3);
1020 }
1021 if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1022 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
1023 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
// Both zero-points non-constant: keep them as broadcastable inputs.
1024 if (failed(maybeOutZp) && failed(maybeInZp))
1025 return operands;
1026 // Input1_zp and output_zp cannot broadcast when they are constants.
// NOTE(review): when exactly one zero-point is non-constant it is still
// dropped here -- confirm that case is handled downstream.
1027 return operands.take_front(1);
1028 }
1029 return operands;
1030}
1031
// Lowers a ranked element-wise TOSA op: computes the runtime target shape,
// broadcasts the dynamic dimensions of the participating operands, then emits
// the payload as a linalg.generic.
1032static LogicalResult
1034 ConversionPatternRewriter &rewriter,
1035 const TypeConverter &converter) {
1036
1037 // Collect op properties
1038 assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1039 assert(operation->getNumOperands() >= 1 &&
1040 "elementwise op expects at least 1 operand");
1041 if (!operandsAndResultsRanked(operation))
1042 return rewriter.notifyMatchFailure(operation,
1043 "Unranked tensors not supported");
1044
1045 // Lower operation
// Only operands selected by getBroadcastableOperands take part in shape
// inference and broadcasting (e.g. constant mul shifts are excluded).
1046 IndexPool indexPool;
1047 auto loc = operation->getLoc();
1048 auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1049 auto [targetShape, masterOperands] =
1050 computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1051 auto broadcastOperands =
1052 broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1053 targetShape, masterOperands);
1054 return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1055 targetShape, converter);
1056}
1057
1058// Returns the constant initial value for a given reduction operation. The
1059// attribute type varies depending on the element type required.
1060static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1061 PatternRewriter &rewriter) {
1062 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1063 return rewriter.getFloatAttr(elementTy, 0.0);
1064
1065 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1066 return rewriter.getIntegerAttr(elementTy, 0);
1067
1068 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1069 return rewriter.getFloatAttr(elementTy, 1.0);
1070
1071 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1072 return rewriter.getIntegerAttr(elementTy, 1);
1073
1074 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1075 return rewriter.getFloatAttr(
1076 elementTy, APFloat::getLargest(
1077 cast<FloatType>(elementTy).getFloatSemantics(), false));
1078
1079 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1080 return rewriter.getIntegerAttr(
1081 elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1082
1083 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1084 return rewriter.getFloatAttr(
1085 elementTy, APFloat::getLargest(
1086 cast<FloatType>(elementTy).getFloatSemantics(), true));
1087
1088 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1089 return rewriter.getIntegerAttr(
1090 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1091
1092 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1093 return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1094
1095 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1096 return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1097
1098 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1099 return rewriter.getFloatAttr(
1100 elementTy, APFloat::getLargest(
1101 cast<FloatType>(elementTy).getFloatSemantics(), true));
1102
1103 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1104 return rewriter.getIntegerAttr(
1105 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1106
1107 return {};
1108}
1109
1110// Creates the body calculation for a reduction. The operations vary depending
1111// on the input type.
// 'args' carries the two values to combine (element and running accumulator).
// Returns a null Value when the op / element-type combination has no
// supported lowering; the caller treats that as a match failure.
1113 ValueRange args,
1114 Type elementTy,
1115 PatternRewriter &rewriter) {
1116 Location loc = op->getLoc();
1117 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1118 return arith::AddFOp::create(rewriter, loc, args);
1119 }
1120
1121 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1122 return arith::AddIOp::create(rewriter, loc, args);
1123 }
1124
1125 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1126 return arith::MulFOp::create(rewriter, loc, args);
1127 }
1128
1129 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1130 return arith::MulIOp::create(rewriter, loc, args);
1131 }
1132
1133 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1134 return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
1135 }
1136
1137 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1138 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
1139 }
1140
1141 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1142 return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
1143 }
1144
1145 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1146 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
1147 }
1148
1149 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1150 return arith::AndIOp::create(rewriter, loc, args);
1151
1152 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1153 return arith::OrIOp::create(rewriter, loc, args);
1154
1155 return {};
1156}
1157
1158// Performs the match and rewrite for reduction operations. This includes
1159// declaring a correctly sized initial value, and the linalg.generic operation
1160// that reduces across the specified axis.
1161template <typename OpTy>
1162static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1163 PatternRewriter &rewriter) {
1164 auto loc = op->getLoc();
1165 auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1166 auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1167 if (!inputTy || !resultTy)
1168 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1169
1170 auto elementTy = resultTy.getElementType();
1171 Value input = op->getOperand(0);
1172
1173 // Figure out the accType if needed
1174 bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1175 isa<FloatType>(elementTy) &&
1176 cast<FloatType>(elementTy).isBF16();
1177 Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
1178
1179 SmallVector<int64_t> reduceShape;
1180 SmallVector<Value> dynDims;
1181 for (unsigned i = 0; i < inputTy.getRank(); i++) {
1182 if (axis != i) {
1183 reduceShape.push_back(inputTy.getDimSize(i));
1184 if (inputTy.isDynamicDim(i))
1185 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1186 }
1187 }
1188
1189 SmallVector<Value> inputs, outputs;
1190 inputs.push_back(input);
1191
1192 // First fill the output buffer with the init value.
1193 auto emptyTensor =
1194 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1195 .getResult();
1196
1197 auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
1198 if (!fillValueAttr)
1199 return rewriter.notifyMatchFailure(
1200 op, "No initial value found for reduction operation");
1201
1202 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
1203 auto filledTensor =
1204 linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
1205 ValueRange{emptyTensor})
1206 .result();
1207 outputs.push_back(filledTensor);
1208
1209 bool isNanIgnoreMode = false;
1210 if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1211 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1212 // NaN propagation has no meaning for non floating point types.
1213 if (isa<FloatType>(elementTy) &&
1214 op.getNanMode() == NanPropagationMode::IGNORE) {
1215 isNanIgnoreMode = true;
1216 // Because the TOSA spec requires the result be NaN iff all elements in
1217 // the reduction are NaN we can't simply perform a compare and select.
1218 // Additionally we have to keep track of whether we've seen any non-NaN
1219 // values and then do a final select based on this predicate.
1220 auto trueAttr = rewriter.getBoolAttr(true);
1221 auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
1222 auto emptyBoolTensor =
1223 tensor::EmptyOp::create(rewriter, loc, reduceShape,
1224 trueValue.getType(), dynDims)
1225 .getResult();
1226 auto allResultsNaNTensor =
1227 linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
1228 ValueRange{emptyBoolTensor})
1229 .result();
1230 // Note that because the linalg::ReduceOp has two variadic arguments
1231 // (inputs and outputs) and it has the SameVariadicOperandSize trait we
1232 // need to have the same number of inputs and outputs.
1233 //
1234 // The second input isn't actually used anywhere since the value used to
1235 // update the NaN flag is calculated inside the body of the reduction and
1236 // then used to update an out value.
1237 // In order to satisfy type constraints we just pass another copy of the
1238 // input here.
1239 inputs.push_back(input);
1240 outputs.push_back(allResultsNaNTensor);
1241 }
1242 }
1243
1244 bool didEncounterError = false;
1245 linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
1246 rewriter, loc, inputs, outputs, axis,
1247 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1248 std::array<Value, 2> binaryArgs{
1249 blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1250
1251 // If reduction type differs then extend (applicable to reduce_sum)
1252 if (binaryArgs[0].getType() != accTy)
1253 binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1254 binaryArgs[0]);
1255
1256 auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
1257 accTy, rewriter);
1258 if (result)
1259 didEncounterError = true;
1260
1261 SmallVector<Value> resultsToYield;
1262 if (isNanIgnoreMode) {
1263 auto inputValue = blockArgs[0];
1264 auto initialValue = blockArgs[2];
1265 auto oldAllResultsNanFlagValue = blockArgs[3];
1266
1267 // Unordered comparison of NaN against itself will always return true.
1268 Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
1269 arith::CmpFPredicate::UNO,
1270 inputValue, inputValue);
1271 // If we've encountered a NaN, take the non-NaN value.
1272 auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
1273 isNaN, initialValue, result);
1274 // Update the flag which keeps track of whether we have seen a non-NaN
1275 // value.
1276 auto newAllResultsNanFlagValue = arith::AndIOp::create(
1277 nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1278 resultsToYield.push_back(selectOp);
1279 resultsToYield.push_back(newAllResultsNanFlagValue);
1280 } else {
1281 resultsToYield.push_back(result);
1282 }
1283 linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
1284 });
1285
1286 if (!didEncounterError)
1287 return rewriter.notifyMatchFailure(
1288 op, "unable to create linalg.generic body for reduce op");
1289
1290 if (isNanIgnoreMode) {
1291 // Materialize a check to see whether we encountered any non-NaN values, if
1292 // we didn't we need to select a tensor of NaNs since the result will just
1293 // be the initial identity value propagated through all the compares and
1294 // selects inside the reduction.
1295
1296 // Create a tensor full of NaNs.
1297 auto nanValueAttr = rewriter.getFloatAttr(
1298 accTy,
1299 APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1300 auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
1301 auto emptyNanTensor =
1302 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1303 .getResult();
1304 auto nanFilledTensor =
1305 linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
1306 ValueRange{emptyNanTensor})
1307 .result();
1308
1309 // Create an empty tensor, non need to fill this since it will be
1310 // overwritten by the select.
1311 auto finalEmptyTensor =
1312 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1313 .getResult();
1314
1315 // Do a selection between the tensors akin to:
1316 // result = NaN if "all results NaN" else result.
1317 SmallVector<Value> ins, outs;
1318 ins.push_back(linalgOp->getOpResult(1));
1319 ins.push_back(nanFilledTensor);
1320 ins.push_back(linalgOp->getResult(0));
1321 outs.push_back(finalEmptyTensor);
1322 auto linalgSelect =
1323 linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
1324 linalgOp = linalgSelect;
1325 }
1326
1327 // Truncate back to resultTy if needed
1328 Value reducedRes = linalgOp->getResult(0);
1329 if (widenAccTy) {
1330 auto resEmptyOp =
1331 tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1332 .getResult();
1333
1334 const unsigned reducedRank =
1335 cast<ShapedType>(reducedRes.getType()).getRank();
1336 auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
1337 reducedRes =
1338 linalg::GenericOp::create(
1339 rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
1340 ValueRange{resEmptyOp},
1341 ArrayRef<AffineMap>{identityMap, identityMap},
1342 getNParallelLoopsAttrs(reducedRank),
1343 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1344 Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1345 elementTy, args[0]);
1346 linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
1347 })
1348 .getResults()[0];
1349 }
1350
1351 SmallVector<ReassociationExprs, 4> reassociationMap;
1352 uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
1353 reassociationMap.resize(expandInputRank);
1354
1355 for (uint64_t i = 0; i < expandInputRank; i++) {
1356 int32_t dimToPush = i > axis ? i + 1 : i;
1357 reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1358 }
1359
1360 if (expandInputRank != 0) {
1361 int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1362 reassociationMap[expandedDim].push_back(
1363 rewriter.getAffineDimExpr(expandedDim + 1));
1364 }
1365
1366 // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1367 // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1368 // not have access to such information. This matters when handling dynamically
1369 // sized tensors.
1370 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
1371 reassociationMap);
1372 return success();
1373}
1374
1375namespace {
1376
// Conversion pattern that lowers a single element-wise TOSA op by delegating
// to elementwiseMatchAndRewrite with the adaptor's operands and the pattern's
// type converter.
1377template <typename SrcOp>
1378class PointwiseConverter : public OpConversionPattern<SrcOp> {
1379public:
1380 using OpConversionPattern<SrcOp>::OpConversionPattern;
1381 using typename OpConversionPattern<SrcOp>::OpAdaptor;
1382
1383 LogicalResult
1384 matchAndRewrite(SrcOp op, OpAdaptor operands,
1385 ConversionPatternRewriter &rewriter) const final {
1387 op, operands.getOperands(), rewriter, *this->getTypeConverter());
1388 }
1389};
1390
1391// Collapse tensor<1xiN> into tensor<iN>
1392// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1393static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
1394 Location loc) {
// The result is a rank-0 tensor: the single static dimension of the input is
// folded away entirely (empty reassociation, as in the example above).
1396 // Create the collapsed type
1397 auto inputType = cast<RankedTensorType>(input.getType());
1398 auto elemType = inputType.getElementType();
1399 auto collapsedType = RankedTensorType::get({}, elemType);
1400 // Emit the collapse op
1401 return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
1402 reassociation);
1403}
1404
1406convertToI8(const llvm::SmallVector<int32_t> &input) {
1408 output.reserve(input.size());
1409
// Truncate each 32-bit value to its low 8 bits (static_cast to int8_t).
1410 for (auto v : llvm::map_range(
1411 input, [](int32_t val) { return static_cast<int8_t>(val); })) {
1412 output.push_back(v);
1413 }
1414 return output;
1415}
1416
1417// The shift or multiplier may be either constant or non-constant, depending on
1418// whether dynamic extension is enabled.
1419// - If the shift or multiplier is non-constant, add it as an input to
1420// linalg::GenericOp by:
1421// 1. Pushing it into 'genericInputs'.
1422// 2. Appending a corresponding affine map to 'indexingMaps'.
1423// - If the shift or multiplier is constant, set 'constant' instead.
// On return 'arg' holds the index of the most recently appended indexing map;
// it is only meaningful when an input was actually pushed (the caller uses
// 'constant' otherwise).
1424static void setupLinalgGenericOpInputAndIndexingMap(
1426 SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1427 bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
1428 bool isShift = false) {
1429
1430 auto loc = op.getLoc();
1431 auto inputTy = cast<ShapedType>(op.getInput().getType());
1432 unsigned rank = inputTy.getRank();
// Per-channel values are indexed along the innermost dimension.
1433 SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
1434
1435 if (isConstant) {
1436 // If we are rescaling per-channel then we need to store the
1437 // values in a buffer.
1438 if (values.size() == 1) {
// Scalar constant: no generic input needed; materialize an arith.constant.
1439 IntegerAttr intAttr = isShift
1440 ? rewriter.getI8IntegerAttr(values.front())
1441 : rewriter.getI32IntegerAttr(values.front());
1442 constant = arith::ConstantOp::create(rewriter, loc, intAttr);
1443 } else {
// Per-channel constants become a dense 1-D constant tensor input (shifts
// stored as i8, multipliers as i32).
1444 auto elementType =
1445 isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
1446 auto tensorType = RankedTensorType::get(
1447 {static_cast<int64_t>(values.size())}, elementType);
1448 DenseIntElementsAttr EltAttr;
1449 if (isShift)
1450 EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
1451 else
1452 EltAttr = DenseIntElementsAttr::get(tensorType, values);
1453 genericInputs.push_back(
1454 arith::ConstantOp::create(rewriter, loc, EltAttr));
1455 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1456 /*symbolCount=*/0, exprs,
1457 rewriter.getContext()));
1458 }
1459 } else {
1460 // If we are not rescaling per-channel then we need to collapse 1xN to N
1461 // and push broadcastMap.
1462 auto operand = isShift ? op.getShift() : op.getMultiplier();
1463 auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
1464 if (tensorType && tensorType.hasStaticShape() &&
1465 tensorType.getShape()[0] == 1) {
1466 // broadcastMap = affine_map<(d0, d1) -> ()>
1467 // It would affect as broadcast for scalar values in linalg::GenericOp.
1468 AffineMap broadcastMap =
1469 AffineMap::get(rank, 0, {}, rewriter.getContext());
1470 genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
1471 indexingMaps.push_back(broadcastMap);
1472 } else {
1473 genericInputs.push_back(operand);
1474 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1475 /*symbolCount=*/0, exprs,
1476 rewriter.getContext()));
1477 }
1478 }
1479 arg = indexingMaps.size() - 1;
1480}
1481
1482// Return the extended Zp to be used in subsequent arithmetic operations.
1483static Value getExtendZp(OpBuilder &builder, Type valueTy,
1484 FailureOr<int64_t> maybeZp, Location loc,
1485 ValueRange blockArgs, int64_t zpArg,
1486 bool isOutputZp = false) {
1487 Value result;
1488 const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1489 const uint32_t attrBitwidth =
1490 isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1491 auto extendType = builder.getIntegerType(attrBitwidth);
1492 // The Zp value can be either constant or non-constant, depending on
1493 // whether dynamic extension is enabled.
1494 // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1495 // be passed as an input to linalg::GenericOp.
1496 if (failed(maybeZp)) {
1497 result = blockArgs[zpArg];
1498 auto zpTy = result.getType();
1499 if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1500 // For ExtUIOp, the input must be signless.
1501 // UnrealizedConversionCastOp will cast the input to signless type.
1502 if (zpTy.isUnsignedInteger()) {
1503 result =
1504 UnrealizedConversionCastOp::create(
1505 builder, loc,
1506 builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
1507 .getResult(0);
1508 }
1509 if (zpTy.isUnsignedInteger()) {
1510 return arith::ExtUIOp::create(builder, loc, extendType, result);
1511 } else {
1512 return arith::ExtSIOp::create(builder, loc, extendType, result);
1513 }
1514 }
1515 } else {
1516 return arith::ConstantOp::create(builder, loc,
1517 IntegerAttr::get(extendType, *maybeZp));
1518 }
1519 return result;
1520}
1521
1522class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1523public:
1524 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1525
1526 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1527 PatternRewriter &rewriter) const final {
1528 auto loc = op.getLoc();
1529 auto input = op.getInput();
1530 auto inputTy = cast<ShapedType>(op.getInput().getType());
1531 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1532 unsigned rank = inputTy.getRank();
1533
1534 // This is an illegal configuration. terminate and log an error
1535 if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
1536 return rewriter.notifyMatchFailure(
1537 op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1538 "currently supported");
1539 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
1540 return rewriter.notifyMatchFailure(
1541 op, "tosa.rescale requires scale32 for double_round to be true");
1542
1543 if (!isa<IntegerType>(inputTy.getElementType()))
1544 return rewriter.notifyMatchFailure(op, "only support integer type");
1545
1546 SmallVector<Value> dynDims;
1547 for (int i = 0; i < outputTy.getRank(); i++) {
1548 if (outputTy.isDynamicDim(i)) {
1549 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1550 }
1551 }
1552
1553 DenseElementsAttr shiftElems;
1554 bool isShiftConstant = false;
1555 if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1556 isShiftConstant = true;
1557
1558 DenseElementsAttr multiplierElems;
1559 bool isMultiplierConstant = false;
1560 if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1561 isMultiplierConstant = true;
1562
1563 llvm::SmallVector<int32_t> shiftValues;
1564 llvm::SmallVector<int32_t> multiplierValues;
1565 bool doubleRound;
1566
1567 if (isMultiplierConstant && isShiftConstant) {
1568 // explicit cast is required here
1569 shiftValues = llvm::map_to_vector(
1570 shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1571 return static_cast<int32_t>(attr.getInt());
1572 });
1573 multiplierValues =
1574 llvm::map_to_vector(multiplierElems.getValues<IntegerAttr>(),
1575 [](IntegerAttr attr) -> int32_t {
1576 return static_cast<int32_t>(attr.getInt());
1577 });
1578
1579 // If we shift by more than the bitwidth, this just sets to 0.
1580 for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1581 if (shiftValues[i] > 63) {
1582 shiftValues[i] = 0;
1583 multiplierValues[i] = 0;
1584 }
1585 }
1586 // Double round only occurs if shift is greater than 31, check that this
1587 // is ever true.
1588 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1589 llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1590 } else
1591 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
1592
1593 RoundingMode roundingMode =
1594 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
1595
1596 SmallVector<AffineMap> indexingMaps = {
1597 rewriter.getMultiDimIdentityMap(rank)};
1598 SmallVector<Value, 4> genericInputs = {input};
1599
1600 // If we are rescaling per-channel then we need to store the multiplier
1601 // values in a buffer.
1602 Value multiplierConstant;
1603 int64_t multiplierArg = 0;
1604 setupLinalgGenericOpInputAndIndexingMap(
1605 rewriter, multiplierValues, genericInputs, indexingMaps,
1606 isMultiplierConstant, op, multiplierConstant, multiplierArg);
1607
1608 // If we are rescaling per-channel then we need to store the shift
1609 // values in a buffer.
1610 Value shiftConstant;
1611 int64_t shiftArg = 0;
1612 setupLinalgGenericOpInputAndIndexingMap(
1613 rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1614 shiftConstant, shiftArg, true);
1615
1616 // broadcastMap = affine_map<(d0, d1) -> ()>
1617 // It would affect as broadcast for scalar values in linalg::GenericOp.
1618 AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1619 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1620 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1621 // The inputZp and outputZp may be either constant or non-constant,
1622 // depending on whether dynamic extension is enabled.
1623 // - If the zp's are non-constant, add them as an inputs to
1624 // linalg::GenericOp by:
1625 // 1. Pushing it into 'genericInputs'.
1626 // 2. Appending a corresponding affine map to 'indexingMaps'.
1627 // - If the zp's are constant, they would be generated as arith.constant.
1628 int64_t iZpArg = 0;
1629 if (failed(maybeIZp)) {
1630 genericInputs.push_back(
1631 collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1632 indexingMaps.push_back(broadcastMap);
1633 iZpArg = indexingMaps.size() - 1;
1634 }
1635 int64_t oZpArg = 0;
1636 if (failed(maybeOZp)) {
1637 genericInputs.push_back(
1638 collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1639 indexingMaps.push_back(broadcastMap);
1640 oZpArg = indexingMaps.size() - 1;
1641 }
1642
1643 // Indexing maps for output values.
1644 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1645
1646 // Construct the indexing maps needed for linalg.generic ops.
1647 Value emptyTensor = tensor::EmptyOp::create(
1648 rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
1649 ArrayRef<Value>({dynDims}));
1650
1651 auto linalgOp = linalg::GenericOp::create(
1652 rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
1653 indexingMaps, getNParallelLoopsAttrs(rank),
1654 [&](OpBuilder &nestedBuilder, Location nestedLoc,
1655 ValueRange blockArgs) {
1656 Value value = blockArgs[0];
1657 Type valueTy = value.getType();
1658
1659 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1660 auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1661 nestedLoc, blockArgs, iZpArg);
1662
1663 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1664 auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1665 nestedLoc, blockArgs, oZpArg, true);
1666
1667 IntegerType outIntType =
1668 cast<IntegerType>(blockArgs.back().getType());
1669 unsigned outBitWidth = outIntType.getWidth();
1670 assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1671
1672 Value multiplier = multiplierConstant ? multiplierConstant
1673 : blockArgs[multiplierArg];
1674 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1675
1676 if (valueTy.isUnsignedInteger()) {
1677 value = UnrealizedConversionCastOp::create(
1678 nestedBuilder, nestedLoc,
1679 nestedBuilder.getIntegerType(
1680 valueTy.getIntOrFloatBitWidth()),
1681 value)
1682 .getResult(0);
1683 }
1684 if (valueTy.getIntOrFloatBitWidth() < 32) {
1685 if (op.getInputUnsigned()) {
1686 value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
1687 nestedBuilder.getI32Type(), value);
1688 } else {
1689 value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
1690 nestedBuilder.getI32Type(), value);
1691 }
1692 }
1693
1694 value =
1695 arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
1696
1697 value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
1698 nestedBuilder.getI32Type(), value,
1699 multiplier, shift, roundingMode);
1700
1701 // Move to the new zero-point.
1702 value =
1703 arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
1704
1705 // Saturate to the output size.
1706 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1707 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1708
1709 // Unsigned integers have a difference output value.
1710 if (op.getOutputUnsigned()) {
1711 intMin = 0;
1712 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1713 }
1714
1715 auto intMinVal = arith::ConstantOp::create(
1716 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
1717 auto intMaxVal = arith::ConstantOp::create(
1718 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
1719
1720 value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1721 nestedBuilder, /*isUnsigned=*/false);
1722
1723 if (outIntType.getWidth() < 32) {
1724 value = arith::TruncIOp::create(
1725 nestedBuilder, nestedLoc,
1726 rewriter.getIntegerType(outIntType.getWidth()), value);
1727 }
1728
1729 if (outIntType.isUnsignedInteger()) {
1730 value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
1731 outIntType, value)
1732 .getResult(0);
1733 }
1734 linalg::YieldOp::create(nestedBuilder, loc, value);
1735 });
1736
1737 rewriter.replaceOp(op, linalgOp->getResults());
1738 return success();
1739 }
1740};
1741
// Handle the resize case where the input is a 1x1 image. This case
// can entirely avoid having extract operations, which are much more
// difficult to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  // Lower a tosa.resize whose input AND output are both 1x1 in the spatial
  // dimensions. The spatial dims are collapsed away, the (possibly quantized)
  // value is rescaled elementwise, and the result is expanded back to NHWC.
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    // NHWC layout: dim 1 is height, dim 2 is width.
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    // Same element type and shape: the resize is a no-op.
    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    // The scale operand must be a compile-time constant shape.
    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
      return failure();
    }

    // Collapse the unit width and height away: (N, 1, 1, C) -> (N, C).
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation that casts and scales the input
    // value into the result element type.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case: widen, then multiply by the scale
          // numerators so the result matches the bilinear accumulation.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    // Expand (N, C) back to the full NHWC result shape.
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
1842
1843// TOSA resize with width or height of 1 may be broadcasted to a wider
1844// dimension. This is done by materializing a new tosa.resize without
1845// the broadcasting behavior, and an explicit broadcast afterwards.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  // Split a broadcasting tosa.resize (unit input H and/or W resized to a
  // larger output dim) into a non-broadcasting resize followed by a
  // collapse + linalg.generic broadcast to the final shape.
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    // NHWC layout.
    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Match only when at least one spatial dim broadcasts 1 -> many.
    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    // The new resize op keeps the original attributes but targets the
    // narrowed shape, so it no longer broadcasts.
    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse any unit result dims into their neighbor so the broadcast
    // can be expressed as a rank-preserving generic below.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(
        builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    // Input map drops the broadcast dims; output map is the identity, so the
    // generic replicates values along the broadcast dims.
    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
1942
1943class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1944public:
1945 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1946
1947 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1948 PatternRewriter &rewriter) const final {
1949 Location loc = op.getLoc();
1950 ImplicitLocOpBuilder b(loc, rewriter);
1951 auto input = op.getInput();
1952 auto inputTy = cast<ShapedType>(input.getType());
1953 auto resultTy = cast<ShapedType>(op.getType());
1954 auto resultETy = resultTy.getElementType();
1955
1956 bool floatingPointMode = isa<FloatType>(resultETy);
1957 auto floatTy = resultETy;
1958
1959 auto imageH = inputTy.getShape()[1];
1960 auto imageW = inputTy.getShape()[2];
1961
1962 auto dynamicDimsOr =
1963 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1964 if (!dynamicDimsOr.has_value())
1965 return rewriter.notifyMatchFailure(
1966 op, "unable to get dynamic dimensions of tosa.resize");
1967
1968 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1969 op.getMode() != ResizeMode::BILINEAR)
1970 return rewriter.notifyMatchFailure(
1971 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1972
1973 SmallVector<AffineMap, 2> affineMaps = {
1974 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1975 auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
1976 resultETy, *dynamicDimsOr);
1977 auto genericOp = linalg::GenericOp::create(
1978 b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1979 getNParallelLoopsAttrs(resultTy.getRank()));
1980 Value resize = genericOp.getResult(0);
1981
1982 {
1983 OpBuilder::InsertionGuard regionGuard(b);
1984 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1985 TypeRange({resultETy}), loc);
1986 Value batch = linalg::IndexOp::create(b, 0);
1987 Value y = linalg::IndexOp::create(b, 1);
1988 Value x = linalg::IndexOp::create(b, 2);
1989 Value channel = linalg::IndexOp::create(b, 3);
1990
1991 Value zeroI32 =
1992 arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
1993 Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
1994 Value hMax =
1995 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
1996 Value wMax =
1997 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
1998
1999 Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
2000 Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
2001
2002 SmallVector<int64_t> scale, offset, border;
2003 if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
2004 !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
2005 !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
2006 return rewriter.notifyMatchFailure(
2007 op, "tosa.resize scale/offset/border should have compile time "
2008 "constant values.");
2009 }
2010
2011 Value yScaleN, yScaleD, xScaleN, xScaleD;
2012 yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
2013 yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
2014 xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
2015 xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
2016
2017 Value yOffset, xOffset, yBorder, xBorder;
2018 yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
2019 xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
2020 yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
2021 xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
2022
2023 // Compute the ix and dx values for both the X and Y dimensions.
2024 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
2025 Value scaleN, Value scaleD, Value offset,
2026 int size, ImplicitLocOpBuilder &b) {
2027 if (size == 1) {
2028 index = zeroI32;
2029 delta = zeroFp;
2030 return;
2031 }
2032 // x = x * scale_d + offset;
2033 // ix = floor(x / scale_n)
2034 Value val = arith::MulIOp::create(b, in, scaleD);
2035 val = arith::AddIOp::create(b, val, offset);
2036 index = arith::FloorDivSIOp::create(b, val, scaleN);
2037
2038 // rx = x - ix * scale_n (x % scale_n, if values are positive)
2039 Value scaledIndex = arith::MulIOp::create(b, index, scaleN);
2040 Value r = arith::SubIOp::create(b, val, scaledIndex);
2041 Value rFp = arith::SIToFPOp::create(b, floatTy, r);
2042
2043 // dx = rx / scale_n
2044 Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
2045 delta = arith::DivFOp::create(b, rFp, scaleNfp);
2046 };
2047
2048 // Compute the ix and dx values for the X and Y dimensions - int case.
2049 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
2050 Value scaleN, Value scaleD, Value offset,
2051 int size, ImplicitLocOpBuilder &b) {
2052 if (size == 1) {
2053 index = zeroI32;
2054 delta = zeroI32;
2055 return;
2056 }
2057 // x = x * scale_d + offset;
2058 // ix = floor(x / scale_n)
2059 // dx = x - ix * scale_n;
2060 Value val = arith::MulIOp::create(b, in, scaleD);
2061 val = arith::AddIOp::create(b, val, offset);
2062 index = arith::FloorDivSIOp::create(b, val, scaleN);
2063 delta = arith::MulIOp::create(b, index, scaleN);
2064 delta = arith::SubIOp::create(b, val, delta);
2065 };
2066
2067 Value ix, iy, dx, dy;
2068 if (floatingPointMode) {
2069 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2070 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2071 } else {
2072 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2073 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2074 }
2075
2076 if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
2077 auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2078
2079 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
2080 Value max, int size,
2081 ImplicitLocOpBuilder &b) -> Value {
2082 if (size == 1) {
2084 }
2085
2086 Value pred;
2087 if (floatingPointMode) {
2088 auto h =
2089 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
2090 pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
2091 } else {
2092 Value dvalDouble = arith::ShLIOp::create(b, dval, one);
2093 pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
2094 dvalDouble, scale);
2095 }
2096
2097 auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
2098 val = arith::AddIOp::create(b, val, offset);
2099 val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
2100 return arith::IndexCastOp::create(b, b.getIndexType(), val);
2101 };
2102
2103 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
2104 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
2105
2106 Value result = tensor::ExtractOp::create(
2107 b, input, ValueRange{batch, iy, ix, channel});
2108
2109 linalg::YieldOp::create(b, result);
2110 } else {
2111 // The mode here must be BILINEAR.
2112 assert(op.getMode() == ResizeMode::BILINEAR);
2113
2114 auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2115
2116 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
2117 Value max, ImplicitLocOpBuilder &b) {
2118 val0 = in;
2119 val1 = arith::AddIOp::create(b, val0, oneVal);
2120 val0 =
2121 clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
2122 val1 =
2123 clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
2124 val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
2125 val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
2126 };
2127
2128 // Linalg equivalent to the section below:
2129 // int16_t iy0 = apply_max(iy, 0);
2130 // int16_t iy1 = apply_min(iy + 1, IH - 1);
2131 // int16_t ix0 = apply_max(ix, 0);
2132 // int16_t ix1 = apply_min(ix + 1, IW - 1);
2133 Value x0, x1, y0, y1;
2134 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
2135 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
2136
2137 Value y0x0 = tensor::ExtractOp::create(
2138 b, input, ValueRange{batch, y0, x0, channel});
2139 Value y0x1 = tensor::ExtractOp::create(
2140 b, input, ValueRange{batch, y0, x1, channel});
2141 Value y1x0 = tensor::ExtractOp::create(
2142 b, input, ValueRange{batch, y1, x0, channel});
2143 Value y1x1 = tensor::ExtractOp::create(
2144 b, input, ValueRange{batch, y1, x1, channel});
2145
2146 if (floatingPointMode) {
2147 auto oneVal =
2148 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
2149 auto interpolate = [&](Value val0, Value val1, Value delta,
2150 int inputSize,
2151 ImplicitLocOpBuilder &b) -> Value {
2152 if (inputSize == 1)
2153 return val0;
2154 Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
2155 Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
2156 Value mul1 = arith::MulFOp::create(b, val1, delta);
2157 return arith::AddFOp::create(b, mul0, mul1);
2158 };
2159
2160 // Linalg equivalent to the section below:
2161 // topAcc = v00 * (unit_x - dx);
2162 // topAcc += v01 * dx;
2163 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
2164
2165 // Linalg equivalent to the section below:
2166 // bottomAcc = v10 * (unit_x - dx);
2167 // bottomAcc += v11 * dx;
2168 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
2169
2170 // Linalg equivalent to the section below:
2171 // result = topAcc * (unit_y - dy) + bottomAcc * dy
2172 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
2173 linalg::YieldOp::create(b, result);
2174 } else {
2175 // Perform in quantized space.
2176 y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
2177 y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
2178 y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
2179 y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
2180
2181 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
2182 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
2183 dx = arith::ExtSIOp::create(b, resultETy, dx);
2184 dy = arith::ExtSIOp::create(b, resultETy, dy);
2185 }
2186
2187 Value yScaleNExt = yScaleN;
2188 Value xScaleNExt = xScaleN;
2189
2190 const int64_t scaleBitwidth =
2191 xScaleN.getType().getIntOrFloatBitWidth();
2192 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2193 yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
2194 xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
2195 }
2196
2197 auto interpolate = [](Value val0, Value val1, Value weight1,
2198 Value scale, int inputSize,
2199 ImplicitLocOpBuilder &b) -> Value {
2200 if (inputSize == 1)
2201 return arith::MulIOp::create(b, val0, scale);
2202 Value weight0 = arith::SubIOp::create(b, scale, weight1);
2203 Value mul0 = arith::MulIOp::create(b, val0, weight0);
2204 Value mul1 = arith::MulIOp::create(b, val1, weight1);
2205 return arith::AddIOp::create(b, mul0, mul1);
2206 };
2207
2208 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2209 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2210 Value result =
2211 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2212 linalg::YieldOp::create(b, result);
2213 }
2214 }
2215 }
2216
2217 rewriter.replaceOp(op, resize);
2218 return success();
2219 }
2220};
2221
2222// At the codegen level any identity operations should be removed. Any cases
2223// where identity is load-bearing (e.g. cross device computation) should be
2224// handled before lowering to codegen.
2225template <typename SrcOp>
2226class IdentityNConverter : public OpRewritePattern<SrcOp> {
2227public:
2228 using OpRewritePattern<SrcOp>::OpRewritePattern;
2229
2230 LogicalResult matchAndRewrite(SrcOp op,
2231 PatternRewriter &rewriter) const final {
2232 rewriter.replaceOp(op, op.getOperation()->getOperands());
2233 return success();
2234 }
2235};
2236
2237template <typename SrcOp>
2238class ReduceConverter : public OpRewritePattern<SrcOp> {
2239public:
2240 using OpRewritePattern<SrcOp>::OpRewritePattern;
2241
2242 LogicalResult matchAndRewrite(SrcOp reduceOp,
2243 PatternRewriter &rewriter) const final {
2244 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2245 }
2246};
2247
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  // Lower tosa.reverse to a linalg.generic that gathers each output element
  // from the mirrored position (size - 1 - i) along the reversed axis.
  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    // Collect dynamic dimension sizes needed by tensor.empty.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // Runtime extent of the reversed axis (works for dynamic dims too).
    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = tensor::EmptyOp::create(
                           rewriter, loc, inputTy.getShape(),
                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            if (i == axis) {
              // Mirror the index on the reversed axis: size - 1 - i.
              auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }

            indices.push_back(index);
          }

          // Gather the mirrored element and yield it.
          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
2304
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcasted to the appropriate
// multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  // Lower tosa.tile as: broadcast into an interleaved shape
  // (m0, d0, m1, d1, ...) via linalg.generic, then tosa.reshape to the
  // final result shape.
  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Multiples must be compile-time constants (-1 marks a dynamic one).
    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple:
    // interleave (multiple_i, dim_i) pairs into the generic's shape.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    // Dynamic sizes for tensor.empty.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions,
    // i.e. the odd positions of the interleaved shape.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    // The generic simply replicates each input element across the even
    // (multiple) dimensions.
    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    // Fold the interleaved shape back into the expected result shape.
    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
2373
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
2387class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2388public:
2389 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2390
2391 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2392 PatternRewriter &rewriter) const final {
2393 auto loc = argmaxOp.getLoc();
2394 Value input = argmaxOp.getInput();
2395 auto inputTy = cast<ShapedType>(input.getType());
2396 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2397 auto inElementTy = inputTy.getElementType();
2398 auto outElementTy = resultTy.getElementType();
2399 int axis = argmaxOp.getAxis();
2400 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2401
2402 if (!isa<IntegerType>(outElementTy))
2403 return rewriter.notifyMatchFailure(
2404 argmaxOp,
2405 "tosa.arg_max to linalg.* requires integer-like result type");
2406
2407 SmallVector<Value> dynDims;
2408 for (int i = 0; i < inputTy.getRank(); i++) {
2409 if (inputTy.isDynamicDim(i) && i != axis) {
2410 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2411 }
2412 }
2413
2414 // First fill the output buffer for the index.
2415 auto emptyTensorIdx =
2416 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2417 outElementTy, dynDims)
2418 .getResult();
2419 auto fillValueIdx = arith::ConstantOp::create(
2420 rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2421 auto filledTensorIdx =
2422 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
2423 ValueRange{emptyTensorIdx})
2424 .result();
2425
2426 // Second fill the output buffer for the running max.
2427 auto emptyTensorMax =
2428 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2429 dynDims)
2430 .getResult();
2431 auto fillValueMaxAttr =
2432 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2433
2434 if (!fillValueMaxAttr)
2435 return rewriter.notifyMatchFailure(
2436 argmaxOp, "unsupported tosa.argmax element type");
2437
2438 auto fillValueMax =
2439 arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2440 auto filledTensorMax =
2441 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
2442 ValueRange{emptyTensorMax})
2443 .result();
2444
2445 // We need to reduce along the arg-max axis, with parallel operations along
2446 // the rest.
2447 SmallVector<utils::IteratorType, 4> iteratorTypes;
2448 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2449 iteratorTypes[axis] = utils::IteratorType::reduction;
2450
2451 SmallVector<AffineExpr, 2> srcExprs;
2452 SmallVector<AffineExpr, 2> dstExprs;
2453 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2454 srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2455 if (axis != i)
2456 dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2457 }
2458
2459 bool didEncounterError = false;
2460 auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2461 rewriter.getContext());
2462 auto linalgOp = linalg::GenericOp::create(
2463 rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2464 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2465 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2466 ValueRange blockArgs) {
2467 auto newValue = blockArgs[0];
2468 auto oldIndex = blockArgs[1];
2469 auto oldValue = blockArgs[2];
2470
2471 Value newIndex = arith::IndexCastOp::create(
2472 rewriter, nestedLoc, oldIndex.getType(),
2473 linalg::IndexOp::create(rewriter, loc, axis));
2474
2475 Value predicate;
2476 if (isa<FloatType>(inElementTy)) {
2477 if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
2478 // Only update index & max value for non NaN values. If all
2479 // values are NaNs, the initial index will be return which is 0.
2480 predicate = arith::CmpFOp::create(rewriter, nestedLoc,
2481 arith::CmpFPredicate::OGT,
2482 newValue, oldValue);
2483 } else {
2484 // Update max value if either of the following is true:
2485 // - new value is bigger
2486 // - cur max is not NaN and new value is NaN
2487 Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
2488 arith::CmpFPredicate::UGT,
2489 newValue, oldValue);
2490 Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
2491 arith::CmpFPredicate::ORD,
2492 oldValue, oldValue);
2493 predicate = arith::AndIOp::create(
2494 rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2495 }
2496 } else if (isa<IntegerType>(inElementTy)) {
2497 predicate = arith::CmpIOp::create(rewriter, nestedLoc,
2498 arith::CmpIPredicate::sgt,
2499 newValue, oldValue);
2500 } else {
2501 didEncounterError = true;
2502 return;
2503 }
2504
2505 auto resultMax = arith::SelectOp::create(
2506 rewriter, nestedLoc, predicate, newValue, oldValue);
2507 auto resultIndex = arith::SelectOp::create(
2508 rewriter, nestedLoc, predicate, newIndex, oldIndex);
2509 linalg::YieldOp::create(nestedBuilder, nestedLoc,
2510 ValueRange({resultIndex, resultMax}));
2511 });
2512
2513 if (didEncounterError)
2514 return rewriter.notifyMatchFailure(
2515 argmaxOp, "unsupported tosa.argmax element type");
2516
2517 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2518 return success();
2519 }
2520};
2521
2522class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2523public:
2524 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2525 LogicalResult
2526 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2527 ConversionPatternRewriter &rewriter) const final {
2528 auto input = adaptor.getOperands()[0];
2529 auto indices = adaptor.getOperands()[1];
2530
2531 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2532 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2533 if (!valuesTy || !resultTy)
2534 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2535
2536 auto dynamicDims = inferDynamicDimsForGather(
2537 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2538
2539 auto resultElementTy = resultTy.getElementType();
2540
2541 auto loc = op.getLoc();
2542 auto emptyTensor =
2543 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2544 resultElementTy, dynamicDims)
2545 .getResult();
2546
2547 SmallVector<AffineMap, 2> affineMaps = {
2549 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2550 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2551 rewriter.getContext()),
2552 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2553
2554 auto genericOp = linalg::GenericOp::create(
2555 rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2556 ValueRange{emptyTensor}, affineMaps,
2557 getNParallelLoopsAttrs(resultTy.getRank()),
2558 [&](OpBuilder &b, Location loc, ValueRange args) {
2559 auto indexValue = args[0];
2560 auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
2561 Value index1 = arith::IndexCastOp::create(
2562 rewriter, loc, rewriter.getIndexType(), indexValue);
2563 auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
2564 Value extract = tensor::ExtractOp::create(
2565 rewriter, loc, input, ValueRange{index0, index1, index2});
2566 linalg::YieldOp::create(rewriter, loc, extract);
2567 });
2568 rewriter.replaceOp(op, genericOp.getResult(0));
2569 return success();
2570 }
2571
2572 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2573 Location loc,
2574 Value values,
2575 Value indices) {
2576 llvm::SmallVector<Value> results;
2577
2578 auto addDynamicDimension = [&](Value source, int64_t dim) {
2579 auto sz = tensor::getMixedSize(builder, loc, source, dim);
2580 if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2581 results.push_back(dimValue);
2582 };
2583
2584 addDynamicDimension(values, 0);
2585 addDynamicDimension(indices, 1);
2586 addDynamicDimension(values, 2);
2587 return results;
2588 }
2589};
2590
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
2594class TableConverter : public OpRewritePattern<tosa::TableOp> {
2595public:
2596 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2597
2598 LogicalResult matchAndRewrite(tosa::TableOp op,
2599 PatternRewriter &rewriter) const final {
2600 auto loc = op.getLoc();
2601 Value input = op.getInput1();
2602 Value table = op.getTable();
2603 auto inputTy = cast<ShapedType>(input.getType());
2604 auto tableTy = cast<ShapedType>(table.getType());
2605 auto resultTy = cast<ShapedType>(op.getType());
2606
2607 auto inputElementTy = inputTy.getElementType();
2608 auto tableElementTy = tableTy.getElementType();
2609 auto resultElementTy = resultTy.getElementType();
2610
2611 SmallVector<Value> dynDims;
2612 for (int i = 0; i < resultTy.getRank(); ++i) {
2613 if (inputTy.isDynamicDim(i)) {
2614 dynDims.push_back(
2615 tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
2616 }
2617 }
2618
2619 auto emptyTensor =
2620 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2621 resultElementTy, dynDims)
2622 .getResult();
2623
2624 SmallVector<AffineMap, 2> affineMaps = {
2625 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2626 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2627
2628 auto genericOp = linalg::GenericOp::create(
2629 rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
2630 affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
2631 rewriter.replaceOp(op, genericOp.getResult(0));
2632
2633 {
2634 OpBuilder::InsertionGuard regionGuard(rewriter);
2635 Block *block = rewriter.createBlock(
2636 &genericOp.getRegion(), genericOp.getRegion().end(),
2637 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2638
2639 auto inputValue = block->getArgument(0);
2640 rewriter.setInsertionPointToStart(block);
2641 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2642 resultElementTy.isInteger(8)) {
2643 Value index = arith::IndexCastOp::create(
2644 rewriter, loc, rewriter.getIndexType(), inputValue);
2645 Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
2646 index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
2647 index, offset);
2648 Value extract =
2649 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2650 linalg::YieldOp::create(rewriter, loc, extract);
2651 return success();
2652 }
2653
2654 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2655 resultElementTy.isInteger(32)) {
2656 Value extend = arith::ExtSIOp::create(
2657 rewriter, loc, rewriter.getI32Type(), inputValue);
2658
2659 auto offset = arith::ConstantOp::create(
2660 rewriter, loc, rewriter.getI32IntegerAttr(32768));
2661 auto seven = arith::ConstantOp::create(rewriter, loc,
2662 rewriter.getI32IntegerAttr(7));
2663 auto one = arith::ConstantOp::create(rewriter, loc,
2664 rewriter.getI32IntegerAttr(1));
2665 auto b1111111 = arith::ConstantOp::create(
2666 rewriter, loc, rewriter.getI32IntegerAttr(127));
2667
2668 // Compute the index and fractional part from the input value:
2669 // value = value + 32768
2670 // index = value >> 7;
2671 // fraction = 0x01111111 & value
2672 auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
2673 Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
2674 Value fraction =
2675 arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
2676
2677 // Extract the base and next values from the table.
2678 // base = (int32_t) table[index];
2679 // next = (int32_t) table[index + 1];
2680 Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
2681
2682 index = arith::IndexCastOp::create(rewriter, loc,
2683 rewriter.getIndexType(), index);
2684 indexPlusOne = arith::IndexCastOp::create(
2685 rewriter, loc, rewriter.getIndexType(), indexPlusOne);
2686
2687 Value base =
2688 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2689 Value next = tensor::ExtractOp::create(rewriter, loc, table,
2690 ValueRange{indexPlusOne});
2691
2692 base =
2693 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
2694 next =
2695 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
2696
2697 // Use the fractional part to interpolate between the input values:
2698 // result = (base << 7) + (next - base) * fraction
2699 Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
2700 Value diff = arith::SubIOp::create(rewriter, loc, next, base);
2701 Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
2702 Value result =
2703 arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
2704
2705 linalg::YieldOp::create(rewriter, loc, result);
2706
2707 return success();
2708 }
2709 }
2710
2711 return rewriter.notifyMatchFailure(
2712 op, "unable to create body for tosa.table op");
2713 }
2714};
2715
// Lowers tosa.rfft2d to a linalg.generic computing the direct (naive) real
// DFT: the output H/W dims are parallel loops and the input H/W dims are
// reductions accumulating into zero-initialized real/imag tensors.
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  // Returns (ofr / 2) + 1 as an OpFoldResult, folding to a constant when
  // possible. Used for the halved output width of the real FFT.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = arith::ConstantIndexOp::create(builder, loc, 1);
    auto two = arith::ConstantIndexOp::create(builder, loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Derives the [N, H, W/2 + 1] output tensor type from `input` and appends
  // any dynamic output sizes to `dynamicSizes`.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Builds a zero-filled tensor (linalg.fill over tensor.empty) used as the
  // accumulator init operand of the reduction.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
    auto filledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                               ValueRange{emptyTensor})
            .result();
    return filledTensor;
  }

  // Casts an index value to float `type`, via an unsigned integer wide enough
  // for the target float (i64 when the float has more than 32 bits).
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }

  // Materializes linalg.index for loop dimension `index` as a float value.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  // Builds a list of affine dim expressions for the given dim positions,
  // e.g. affineDimsExpr(b, 0, 3, 4) -> {d0, d3, d4}.
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType =
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation: (N, oH, oW)
    // parallel, (iH, iW) reduction.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2889
2890struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2892
2893 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2894 PatternRewriter &rewriter) const override {
2895 if (!llvm::all_of(fft2d->getOperandTypes(),
2896 RFFT2dConverter::isRankedTensor) ||
2897 !llvm::all_of(fft2d->getResultTypes(),
2898 RFFT2dConverter::isRankedTensor)) {
2899 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2900 }
2901
2902 Location loc = fft2d.getLoc();
2903 Value input_real = fft2d.getInputReal();
2904 Value input_imag = fft2d.getInputImag();
2905 BoolAttr inverse = fft2d.getInverseAttr();
2906
2907 auto real_el_ty = cast<FloatType>(
2908 cast<ShapedType>(input_real.getType()).getElementType());
2909 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2910 cast<ShapedType>(input_imag.getType()).getElementType());
2911
2912 assert(real_el_ty == imag_el_ty);
2913
2914 // Compute the output type and set of dynamic sizes
2915 SmallVector<Value> dynamicSizes;
2916
2917 // Get [N, H, W]
2918 auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2919
2920 SmallVector<int64_t, 3> staticSizes;
2921 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2922
2923 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2924
2925 // Iterator types for the linalg.generic implementation
2926 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2927 utils::IteratorType::parallel, utils::IteratorType::parallel,
2928 utils::IteratorType::parallel, utils::IteratorType::reduction,
2929 utils::IteratorType::reduction};
2930
2931 // Inputs/outputs to the linalg.generic implementation
2932 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2933 SmallVector<Value> genericOpOutputs = {
2934 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2935 dynamicSizes),
2936 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2937 dynamicSizes)};
2938
2939 // Indexing maps for input and output tensors
2940 auto indexingMaps = AffineMap::inferFromExprList(
2941 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2942 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2943 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2944 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2945 rewriter.getContext());
2946
2947 // Width and height dimensions of the original input.
2948 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2949 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2950
2951 // Constants and dimension sizes
2952 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2953 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2954 Value constH =
2955 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2956 Value constW =
2957 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2958
2959 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2960 Value valReal = args[0];
2961 Value valImag = args[1];
2962 Value sumReal = args[2];
2963 Value sumImag = args[3];
2964
2965 // Indices for angle computation
2966 Value oy = linalg::IndexOp::create(builder, loc, 1);
2967 Value ox = linalg::IndexOp::create(builder, loc, 2);
2968 Value iy = linalg::IndexOp::create(builder, loc, 3);
2969 Value ix = linalg::IndexOp::create(builder, loc, 4);
2970
2971 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2972 // ox) % W ) / W);
2973 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2974 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2975
2976 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2977 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2978
2979 auto iyRemFloat =
2980 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2981 auto ixRemFloat =
2982 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2983
2984 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2985 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2986
2987 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2988 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2989
2990 if (inverse.getValue()) {
2991 angle = arith::MulFOp::create(
2992 builder, loc, angle,
2993 arith::ConstantOp::create(rewriter, loc,
2994 rewriter.getFloatAttr(real_el_ty, -1.0)));
2995 }
2996
2997 // realComponent = val_real * cos(a) + val_imag * sin(a);
2998 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2999 auto cosAngle = math::CosOp::create(builder, loc, angle);
3000 auto sinAngle = math::SinOp::create(builder, loc, angle);
3001
3002 auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
3003 auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
3004 auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
3005
3006 auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
3007 auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
3008
3009 auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
3010
3011 // outReal = sumReal + realComponent
3012 // outImag = sumImag - imagComponent
3013 auto outReal =
3014 arith::AddFOp::create(builder, loc, sumReal, realComponent);
3015 auto outImag =
3016 arith::AddFOp::create(builder, loc, sumImag, imagComponent);
3017
3018 linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
3019 };
3020
3021 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
3022 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
3023 indexingMaps, iteratorTypes, buildBody);
3024
3025 return success();
3026 }
3027};
3028
3029} // namespace
3030
3032 const TypeConverter &converter, RewritePatternSet *patterns) {
3033
3034 // We have multiple resize coverters to handle degenerate cases.
3035 patterns->add<GenericResizeConverter>(patterns->getContext(),
3036 /*benefit=*/100);
3037 patterns->add<ResizeUnaryConverter>(patterns->getContext(),
3038 /*benefit=*/200);
3039 patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
3040 /*benefit=*/300);
3041
3042 patterns->add<
3043 // clang-format off
3044 PointwiseConverter<tosa::AddOp>,
3045 PointwiseConverter<tosa::SubOp>,
3046 PointwiseConverter<tosa::MulOp>,
3047 PointwiseConverter<tosa::IntDivOp>,
3048 PointwiseConverter<tosa::NegateOp>,
3049 PointwiseConverter<tosa::PowOp>,
3050 PointwiseConverter<tosa::ReciprocalOp>,
3051 PointwiseConverter<tosa::RsqrtOp>,
3052 PointwiseConverter<tosa::LogOp>,
3053 PointwiseConverter<tosa::ExpOp>,
3054 PointwiseConverter<tosa::AbsOp>,
3055 PointwiseConverter<tosa::SinOp>,
3056 PointwiseConverter<tosa::CosOp>,
3057 PointwiseConverter<tosa::TanhOp>,
3058 PointwiseConverter<tosa::ErfOp>,
3059 PointwiseConverter<tosa::BitwiseAndOp>,
3060 PointwiseConverter<tosa::BitwiseOrOp>,
3061 PointwiseConverter<tosa::BitwiseNotOp>,
3062 PointwiseConverter<tosa::BitwiseXorOp>,
3063 PointwiseConverter<tosa::LogicalAndOp>,
3064 PointwiseConverter<tosa::LogicalNotOp>,
3065 PointwiseConverter<tosa::LogicalOrOp>,
3066 PointwiseConverter<tosa::LogicalXorOp>,
3067 PointwiseConverter<tosa::CastOp>,
3068 PointwiseConverter<tosa::LogicalLeftShiftOp>,
3069 PointwiseConverter<tosa::LogicalRightShiftOp>,
3070 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
3071 PointwiseConverter<tosa::ClzOp>,
3072 PointwiseConverter<tosa::SelectOp>,
3073 PointwiseConverter<tosa::GreaterOp>,
3074 PointwiseConverter<tosa::GreaterEqualOp>,
3075 PointwiseConverter<tosa::EqualOp>,
3076 PointwiseConverter<tosa::MaximumOp>,
3077 PointwiseConverter<tosa::MinimumOp>,
3078 PointwiseConverter<tosa::CeilOp>,
3079 PointwiseConverter<tosa::FloorOp>,
3080 PointwiseConverter<tosa::ClampOp>,
3081 PointwiseConverter<tosa::SigmoidOp>
3082 >(converter, patterns->getContext());
3083
3084 patterns->add<
3085 IdentityNConverter<tosa::IdentityOp>,
3086 ReduceConverter<tosa::ReduceAllOp>,
3087 ReduceConverter<tosa::ReduceAnyOp>,
3088 ReduceConverter<tosa::ReduceMinOp>,
3089 ReduceConverter<tosa::ReduceMaxOp>,
3090 ReduceConverter<tosa::ReduceSumOp>,
3091 ReduceConverter<tosa::ReduceProductOp>,
3092 ArgMaxConverter,
3093 GatherConverter,
3094 RescaleConverter,
3095 ReverseConverter,
3096 RFFT2dConverter,
3097 FFT2dConverter,
3098 TableConverter,
3099 TileConverter>(patterns->getContext());
3100 // clang-format on
3101}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
DenseMap< int64_t, Value > IndexPool
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static bool operandsAndResultsRanked(Operation *operation)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
BlockArgument getArgument(unsigned i)
Definition Block.h:139
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
FloatType getF32Type()
Definition Builders.cpp:47
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:376
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
MLIRContext * getContext() const
Definition Builders.h:56
IntegerAttr getI8IntegerAttr(int8_t value)
Definition Builders.cpp:225
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:268
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...