TosaToLinalg.cpp
1//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/Sequence.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32
33#include <type_traits>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38// Helper function to materialize the semantically correct compare and select
39// operations given a binary operation with a specific NaN propagation mode.
40//
41// In the case of "PROPAGATE" semantics no compare and select is required and
42// this function simply returns the computed result unchanged.
43//
44// In the case of "IGNORE" semantics this function materializes a comparison
45// of the current operands to the op which returns true for any NaN
46// argument, and then selects between the non-NaN operand and the calculated
47// result based on whether the lhs or rhs is NaN. In pseudo code:
48//
49// binary<op>(lhs, rhs):
50// result = op(lhs, rhs)
51// if lhs == NaN return rhs
52// if rhs == NaN return lhs
53// return result
54//
55// If the op is operating on non-floating-point types, the attribute is
56// ignored entirely; this is consistent with the TOSA spec, which has the
57// following wording: "This attribute is ignored by non floating-point
58// types."
59//
60template <typename OpTy>
61static Value
62materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
63 Value lhs, Value rhs, Value result) {
64 // NaN propagation has no meaning for non floating point types.
65 if (!isa<FloatType>(getElementTypeOrSelf(lhs)))
66 return result;
67
68 auto nanMode = op.getNanMode();
69 if (nanMode == NanPropagationMode::PROPAGATE)
70 return result;
71
72 // Unordered comparison of NaN against itself will always return true.
73 Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
74 arith::CmpFPredicate::UNO, lhs, lhs);
75 Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
76 arith::CmpFPredicate::UNO, rhs, rhs);
77 Value rhsOrResult =
78 arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
79 return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
80 rhsOrResult);
81}
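// A minimal scalar sketch of the "IGNORE" semantics above, assuming plain C++
// floats in place of MLIR values and using the unordered self-compare
// (v != v) that the lowering emits as its NaN test. The helper name and the
// use of max() as the stand-in binary op are illustrative only.
[[maybe_unused]] static float sketchBinaryNanIgnore(float lhs, float rhs) {
  float result = lhs > rhs ? lhs : rhs; // stand-in for op(lhs, rhs)
  if (lhs != lhs)                       // lhs is NaN
    return rhs;
  if (rhs != rhs)                       // rhs is NaN
    return lhs;
  return result;
}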
82
83static Value createLinalgBodyCalculationForElementwiseOp(
84 Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
85 ConversionPatternRewriter &rewriter) {
86 Location loc = op->getLoc();
87 auto elementTy =
88 cast<ShapedType>(op->getOperand(0).getType()).getElementType();
89
90 // tosa::AbsOp
91 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
92 return math::AbsFOp::create(rewriter, loc, resultTypes, args);
93
94 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
95 auto zero = arith::ConstantOp::create(rewriter, loc,
96 rewriter.getZeroAttr(elementTy));
97 auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
98 return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
99 }
100
101 // tosa::AddOp
102 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
103 return arith::AddFOp::create(rewriter, loc, resultTypes, args);
104
105 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
106 return arith::AddIOp::create(rewriter, loc, resultTypes, args);
107
108 // tosa::SubOp
109 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
110 return arith::SubFOp::create(rewriter, loc, resultTypes, args);
111
112 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
113 return arith::SubIOp::create(rewriter, loc, resultTypes, args);
114
115 // tosa::IntDivOp
116 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
117 return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
118
119 // tosa::ReciprocalOp
120 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
121 auto one =
122 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
123 return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
124 }
125
126 // tosa::MulOp
127 if (isa<tosa::MulOp>(op)) {
128 auto shiftVal = cast<tosa::MulOp>(op).getShift();
129 DenseElementsAttr shiftElem;
130 bool shiftIsConstant = true;
131 int32_t shift = 0;
132 if (matchPattern(shiftVal, m_Constant(&shiftElem)))
133 shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
134 else
135 shiftIsConstant = false;
136
137 if (isa<FloatType>(elementTy)) {
138 if (shift != 0) {
139 (void)rewriter.notifyMatchFailure(op,
140 "Cannot have shift value for float");
141 return nullptr;
142 }
143 return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
144 args[1]);
145 }
146
147 if (isa<IntegerType>(elementTy)) {
148 Value a = args[0];
149 Value b = args[1];
150
151 if (shift > 0 || !shiftIsConstant) {
152 Value shiftConst;
153 if (shiftIsConstant)
154 shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
155 /*bitwidth=*/8);
156
157 if (!a.getType().isInteger(32))
158 a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
159
160 if (!b.getType().isInteger(32))
161 b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
162
163 auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
164 auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
165 RoundingMode::SINGLE_ROUND);
166 auto result =
167 tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
168 b, shiftAmount, roundingAttr);
169
170 return result;
171 }
172
173 int aWidth = a.getType().getIntOrFloatBitWidth();
174 int bWidth = b.getType().getIntOrFloatBitWidth();
175 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
176
177 if (aWidth < cWidth)
178 a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
179 if (bWidth < cWidth)
180 b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);
181
182 return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
183 }
184 }
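  // A scalar sketch of the shifted integer multiply path above (illustrative
  // only, not used by the lowering), assuming 32-bit operands and the
  // SINGLE_ROUND scaling behaviour of tosa.apply_scale: add half of the
  // shifted-out range before shifting right.
  auto sketchMulWithShift = [](int32_t a, int32_t b, int32_t shift) {
    int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    if (shift <= 0)
      return static_cast<int32_t>(product);
    int64_t round = int64_t(1) << (shift - 1); // single-round adjustment
    return static_cast<int32_t>((product + round) >> shift);
  };
  (void)sketchMulWithShift;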
185
186 // tosa::NegateOp
187 if (isa<tosa::NegateOp>(op)) {
188 auto negate = cast<tosa::NegateOp>(op);
189
190 int64_t inZp = 0, outZp = 0;
191 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
192 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
193 bool hasInZp = !failed(maybeInZp);
194 bool hasOutZp = !failed(maybeOutZp);
195 if (hasInZp)
196 inZp = *maybeInZp;
197 if (hasOutZp)
198 outZp = *maybeOutZp;
199
200 if (isa<FloatType>(elementTy))
201 return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
202
203 if (isa<IntegerType>(elementTy)) {
204 Value zpAddValue;
205 Type intermediateType;
206 // Choose an intermediate type wide enough to hold inZp + outZp - input.
207 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
208 int intermediateBitWidth = 64;
209
210 if (hasInZp && hasOutZp) {
211 // Compute the maximum value that can occur in the intermediate buffer.
212 const int64_t zpAdd = inZp + outZp;
213 const int64_t maxValue =
214 APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
215 std::abs(zpAdd) + 1;
216
217 // Convert that maximum value into the maximum bitwidth needed to
218 // represent it.
219 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
220 intermediateBitWidth = 16;
221 } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
222 intermediateBitWidth = 32;
223 }
224
225 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
226 zpAddValue = arith::ConstantOp::create(
227 rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
228 } else {
229 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
230 Value arg1 = args[1];
231 Value arg2 = args[2];
232 // Avoid verifier-invalid no-op sign-extends; only widen when needed.
233 if (arg1.getType() != intermediateType)
234 arg1 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg1);
235 if (arg2.getType() != intermediateType)
236 arg2 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg2);
237 zpAddValue =
238 arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
239 }
240
241 // The negation can be applied by doing:
242 // outputValue = inZp + outZp - inputValue
243 Value ext = args[0];
244 if (ext.getType() != intermediateType)
245 ext = arith::ExtSIOp::create(rewriter, loc, intermediateType, ext);
246 auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
247
248 // Clamp to the negation range.
249 auto min = arith::ConstantIntOp::create(
250 rewriter, loc, intermediateType,
251 APInt::getSignedMinValue(inputBitWidth).getSExtValue());
252 auto max = arith::ConstantIntOp::create(
253 rewriter, loc, intermediateType,
254 APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
255 auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
256
257 // Truncate to the final value, skipping no-op trunci when widths match.
258 if (clamp.getType() == elementTy)
259 return clamp;
260 return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
261 }
262 }
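  // A scalar sketch of the zero-point negate above (illustrative only, not
  // used by the lowering), assuming signed 8-bit storage and a fixed 32-bit
  // intermediate; the code above derives the intermediate width instead.
  auto sketchNegateWithZp = [](int8_t in, int32_t inZp, int32_t outZp) {
    int32_t sub = (inZp + outZp) - static_cast<int32_t>(in);
    sub = sub < -128 ? -128 : (sub > 127 ? 127 : sub); // clamp to i8 range
    return static_cast<int8_t>(sub);
  };
  (void)sketchNegateWithZp;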
263
264 // tosa::BitwiseAndOp
265 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
266 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
267
268 // tosa::BitwiseOrOp
269 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
270 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
271
272 // tosa::BitwiseNotOp
273 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
274 auto allOnesAttr = rewriter.getIntegerAttr(
275 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
276 auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
277 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
278 }
279
280 // tosa::BitwiseXOrOp
281 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
282 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
283
284 // tosa::LogicalLeftShiftOp
285 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
286 return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
287
288 // tosa::LogicalRightShiftOp
289 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
290 return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
291
292 // tosa::ArithmeticRightShiftOp
293 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
294 auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
295 auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
296 if (!round) {
297 return result;
298 }
299
300 Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
301 auto one = arith::ConstantOp::create(rewriter, loc,
302 IntegerAttr::get(elementTy, 1));
303 auto zero = arith::ConstantOp::create(rewriter, loc,
304 IntegerAttr::get(elementTy, 0));
305 auto i1zero =
306 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
307 auto i1one =
308 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
309
310 // Check that the shift amount (input2) is greater than zero.
311 auto shiftValueGreaterThanZero = arith::CmpIOp::create(
312 rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
313
314 // Check whether the last bit shifted out of input1 is 1.
315 auto subtract =
316 arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
317 auto shifted =
318 arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
319 ->getResults();
320 auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
322 auto isInputOdd =
323 arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
324 // shifted, truncated, isInputOdd can be poison when input2 is 0.
325 auto shouldRound = arith::SelectOp::create(
326 rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
327 auto extended =
328 arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
329 return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
330 }
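  // A scalar sketch of the rounded arithmetic right shift above (illustrative
  // only): shift, then add one when a shift actually happened and the last
  // bit shifted out was set.
  auto sketchRoundedAshr = [](int32_t value, int32_t shift) {
    int32_t result = value >> shift;
    if (shift <= 0)
      return result;
    bool lastOutBit = ((value >> (shift - 1)) & 1) != 0;
    return lastOutBit ? result + 1 : result;
  };
  (void)sketchRoundedAshr;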
331
332 // tosa::ClzOp
333 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
334 return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
335 }
336
337 // tosa::LogicalAnd
338 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
339 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
340
341 // tosa::LogicalNot
342 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
343 auto one = arith::ConstantOp::create(rewriter, loc,
344 rewriter.getIntegerAttr(elementTy, 1));
345 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
346 }
347
348 // tosa::LogicalOr
349 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
350 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
351
352 // tosa::LogicalXor
353 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
354 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
355
356 // tosa::PowOp
357 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
358 return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
359
360 // tosa::RsqrtOp
361 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
362 return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
363
364 // tosa::LogOp
365 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
366 return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
367
368 // tosa::ExpOp
369 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
370 return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
371
372 // tosa::SinOp
373 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
374 return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
375
376 // tosa::CosOp
377 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
378 return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
379
380 // tosa::TanhOp
381 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
382 return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
383
384 // tosa::ErfOp
385 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
386 return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
387
388 // tosa::GreaterOp
389 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
390 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
391 args[0], args[1]);
392
393 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
394 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
395 args[0], args[1]);
396
397 // tosa::GreaterEqualOp
398 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
399 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
400 args[0], args[1]);
401
402 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
403 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
404 args[0], args[1]);
405
406 // tosa::EqualOp
407 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
408 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
409 args[0], args[1]);
410
411 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
412 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
413 args[0], args[1]);
414
415 // tosa::SelectOp
416 if (isa<tosa::SelectOp>(op)) {
417 elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
418 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
419 return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
420 }
421
422 // tosa::MaximumOp
423 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
424 auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
425 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
426 rewriter, args[0], args[1], max);
427 }
428
429 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
430 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
431 }
432
433 // tosa::MinimumOp
434 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
435 auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
436 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
437 rewriter, args[0], args[1], min);
438 }
439
440 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
441 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
442 }
443
444 // tosa::CeilOp
445 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
446 return math::CeilOp::create(rewriter, loc, resultTypes, args);
447
448 // tosa::FloorOp
449 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
450 return math::FloorOp::create(rewriter, loc, resultTypes, args);
451
452 // tosa::ClampOp
453 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
454 bool losesInfo = false;
455 APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
456 APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
457 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
458 APFloat::rmNearestTiesToEven, &losesInfo);
459 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
460 APFloat::rmNearestTiesToEven, &losesInfo);
461 auto min = arith::ConstantOp::create(
462 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
463 auto max = arith::ConstantOp::create(
464 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
465 auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
466
467 auto clampOp = llvm::cast<tosa::ClampOp>(op);
468 const auto nanMode = clampOp.getNanMode();
469
470 // NaN propagation has no meaning for non floating point types.
471 if (!isa<FloatType>(elementTy))
472 return result;
473
474 // In the case of "PROPAGATE" semantics no compare and selection is
475 // required.
476 if (nanMode == NanPropagationMode::PROPAGATE)
477 return result;
478
479 // In the case of "IGNORE" semantics, materialize a comparison of the
480 // current operand against itself, which returns true only for a NaN
481 // argument, and then select between the min bound and the clamped result
482 // based on whether the operand is NaN. In pseudo code:
483 //
484 // clamp<ignore>(x, min_val, max_val):
485 // result = clamp(x, min_val, max_val)
486 // if x == NaN return min_val
487 // return result
488
489 // Unordered comparison of NaN against itself will always return true.
490 Value isNaN = arith::CmpFOp::create(
491 rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
492 // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
493 // is NaN.
494 return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
495 }
496
497 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
498 auto intTy = cast<IntegerType>(elementTy);
499 int64_t min =
500 cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
501 int64_t max =
502 cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
503
504 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
505 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
506 if (intTy.isUnsignedInteger()) {
507 minRepresentable = 0;
508 if (intTy.getIntOrFloatBitWidth() <= 63) {
509 maxRepresentable =
510 (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
511 .getZExtValue();
512 }
513 } else if (intTy.getIntOrFloatBitWidth() <= 64) {
514 // Ensure that min & max fit into signed n-bit constants.
515 minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
516 .getSExtValue();
517 maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
518 .getSExtValue();
519 }
520 // Ensure that the bounds are representable as n-bit signed/unsigned
521 // integers.
522 min = std::max(min, minRepresentable);
523 max = std::max(max, minRepresentable);
524 min = std::min(min, maxRepresentable);
525 max = std::min(max, maxRepresentable);
526
527 auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
528 intTy.getIntOrFloatBitWidth());
529 auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
530 intTy.getIntOrFloatBitWidth());
531 return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
532 intTy.isUnsignedInteger());
533 }
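  // Scalar sketches of the two clamp paths above (illustrative only): for
  // floats, "IGNORE" NaN mode yields the min bound on NaN input; for i8,
  // min_val/max_val are first clipped to the range the storage type can hold.
  auto sketchClampFloatIgnoreNan = [](float x, float lo, float hi) {
    float clamped = x < lo ? lo : (x > hi ? hi : x);
    return (x != x) ? lo : clamped; // unordered self-compare detects NaN
  };
  auto sketchClampInt8 = [](int8_t x, int64_t lo, int64_t hi) {
    lo = lo < -128 ? -128 : lo; // make the bounds representable in i8
    hi = hi > 127 ? 127 : hi;
    int64_t v = x;
    v = v < lo ? lo : (v > hi ? hi : v);
    return static_cast<int8_t>(v);
  };
  (void)sketchClampFloatIgnoreNan;
  (void)sketchClampInt8;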
534
535 // tosa::SigmoidOp
536 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
537 auto one =
538 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
539 auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
540 auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
541 auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
542 return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
543 }
544
545 // tosa::CastOp
546 if (isa<tosa::CastOp>(op)) {
547 Type srcTy = elementTy;
548 Type dstTy = resultTypes.front();
549 if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
550 (void)rewriter.notifyMatchFailure(op, "unsupported type");
551 return nullptr;
552 }
553
554 bool bitExtend =
555 srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
556
557 if (srcTy == dstTy)
558 return args.front();
559
560 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
561 return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
562 ArrayRef<NamedAttribute>());
563
564 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
565 return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
566 ArrayRef<NamedAttribute>());
567
568 // 1-bit integers need to be treated as signless.
569 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
570 return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
571 ArrayRef<NamedAttribute>());
572
573 if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
574 return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
575 ArrayRef<NamedAttribute>());
576
577 // Unsigned integers need an unrealized cast so that they can be passed
578 // to UIToFP.
579 if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
580 auto unrealizedCast =
581 UnrealizedConversionCastOp::create(
582 rewriter, loc,
583 rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
584 .getResult(0);
585 return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
586 unrealizedCast);
587 }
588
589 // All other si-to-fp conversions should be handled by SIToFP.
590 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
591 return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
592 ArrayRef<NamedAttribute>());
593
594 // Casting to boolean: floats only need to be checked as not-equal to zero.
595 if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
596 Value zero = arith::ConstantOp::create(rewriter, loc,
597 rewriter.getFloatAttr(srcTy, 0.0));
598 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
599 args.front(), zero);
600 }
601
602 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
603 auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
604
605 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
606 // Check whether neither int min nor int max can be represented in the
607 // input floating-point type due to its limited exponent range.
608 if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
609 APFloat::semanticsMaxExponent(fltSemantics)) {
610 // Use cmp + select to replace infinities by int min / int max. Other
611 // integral values can be represented in the integer space.
612 auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
613 auto posInf = arith::ConstantOp::create(
614 rewriter, loc,
615 rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
616 APFloat::getInf(fltSemantics)));
617 auto negInf = arith::ConstantOp::create(
618 rewriter, loc,
619 rewriter.getFloatAttr(
620 getElementTypeOrSelf(srcTy),
621 APFloat::getInf(fltSemantics, /*Negative=*/true)));
622 auto overflow = arith::CmpFOp::create(
623 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
624 auto underflow = arith::CmpFOp::create(
625 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
626 auto intMin = arith::ConstantOp::create(
627 rewriter, loc,
628 rewriter.getIntegerAttr(
629 getElementTypeOrSelf(dstTy),
630 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
631 auto intMax = arith::ConstantOp::create(
632 rewriter, loc,
633 rewriter.getIntegerAttr(
634 getElementTypeOrSelf(dstTy),
635 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
636 auto maxClamped =
637 arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
638 return arith::SelectOp::create(rewriter, loc, underflow, intMin,
639 maxClamped);
640 }
641
642 auto intMinFP = arith::ConstantOp::create(
643 rewriter, loc,
644 rewriter.getFloatAttr(
645 getElementTypeOrSelf(srcTy),
646 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
647 .getSExtValue()));
648
649 // Check whether the mantissa has enough bits to represent int max.
650 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
651 dstTy.getIntOrFloatBitWidth() - 1) {
652 // Int min can also be represented since it is a power of two and thus
653 // consists of a single leading bit. Therefore we can clamp the input
654 // in the floating-point domain.
655
656 auto intMaxFP = arith::ConstantOp::create(
657 rewriter, loc,
658 rewriter.getFloatAttr(
659 getElementTypeOrSelf(srcTy),
660 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
661 .getSExtValue()));
662
663 Value clamped =
664 clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
665 return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
666 }
667
668 // Due to earlier check we know exponent range is big enough to represent
669 // int min. We can therefore rely on int max + 1 being representable as
670 // well because it's just int min with a positive sign. So clamp the min
671 // value and compare against that to select the max int value if needed.
672 auto intMaxPlusOneFP = arith::ConstantOp::create(
673 rewriter, loc,
674 rewriter.getFloatAttr(
675 getElementTypeOrSelf(srcTy),
676 static_cast<double>(
677 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
678 .getSExtValue()) +
679 1.0f));
680
681 auto intMax = arith::ConstantOp::create(
682 rewriter, loc,
683 rewriter.getIntegerAttr(
684 getElementTypeOrSelf(dstTy),
685 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
686 auto minClampedFP =
687 arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
688 auto minClamped =
689 arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
690 auto overflow = arith::CmpFOp::create(
691 rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
692 return arith::SelectOp::create(rewriter, loc, overflow, intMax,
693 minClamped);
694 }
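    // A scalar sketch of the saturating float-to-int path above for the
    // common f32 -> i32 case with a finite, already-rounded input
    // (illustrative only): f32's 24-bit mantissa cannot hold int32 max, but
    // its exponent range can, so the low end is clamped in the float domain
    // and the high end is detected by comparing against int32 max + 1, which
    // is exactly representable.
    auto sketchCastF32ToI32 = [](float rounded) {
      const float minFP = -2147483648.0f;       // int32 min, exact in f32
      const float maxPlusOneFP = 2147483648.0f; // int32 max + 1, exact in f32
      if (rounded >= maxPlusOneFP)
        return static_cast<int32_t>(2147483647); // saturate to int32 max
      float low = rounded < minFP ? minFP : rounded;
      return static_cast<int32_t>(low);
    };
    (void)sketchCastF32ToI32;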
695
696 // Casting to boolean: integers only need to be checked as not-equal to
697 // zero.
698 if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
699 Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
700 srcTy.getIntOrFloatBitWidth());
701 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
702 args.front(), zero);
703 }
704
705 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
706 return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
707 ArrayRef<NamedAttribute>());
708
709 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
710 return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
711 }
712 }
713
714 (void)rewriter.notifyMatchFailure(
715 op, "unhandled op for linalg body calculation for elementwise op");
716 return nullptr;
717}
718
719using IndexPool = DenseMap<int64_t, Value>;
720
721// Emit an 'arith.constant' op for the given index if it has not been created
722// yet, or return an existing constant. This prevents excessive creation of
723// redundant constants, easing readability of emitted code for unit tests.
724static Value createIndex(PatternRewriter &rewriter, Location loc,
725 IndexPool &indexPool, int64_t index) {
726 auto [it, inserted] = indexPool.try_emplace(index);
727 if (inserted)
728 it->second =
729 arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
730 return it->second;
731}
732
733static Value getTensorDim(PatternRewriter &rewriter, Location loc,
734 IndexPool &indexPool, Value tensor, int64_t index) {
735 auto indexValue = createIndex(rewriter, loc, indexPool, index);
736 return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
737}
738
739static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
740 IndexPool &indexPool, Value tensor,
741 int64_t index) {
742 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
743 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
744 assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
745 if (shapedType.isDynamicDim(index))
746 return getTensorDim(rewriter, loc, indexPool, tensor, index);
747 return rewriter.getIndexAttr(shapedType.getDimSize(index));
748}
749
750static bool operandsAndResultsRanked(Operation *operation) {
751 auto isRanked = [](Value value) {
752 return isa<RankedTensorType>(value.getType());
753 };
754 return llvm::all_of(operation->getOperands(), isRanked) &&
755 llvm::all_of(operation->getResults(), isRanked);
756}
757
758// Compute the runtime dimension size for dimension 'dim' of the output by
759// inspecting input 'operands', all of which are expected to have the same rank.
760// This function returns a pair {targetSize, masterOperand}.
761//
762// The runtime size of the output dimension is returned either as a statically
763// computed attribute or as a runtime SSA value.
764//
765// If the target size was inferred directly from one dominating operand, that
766// operand is returned in 'masterOperand'. If the target size is inferred from
767// multiple operands, 'masterOperand' is set to nullptr.
768static std::pair<OpFoldResult, Value>
769computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
770 ValueRange operands, int64_t dim) {
771 // If any input operand contains a static size greater than 1 for this
772 // dimension, that is the target size. An occurrence of an additional static
773 // dimension greater than 1 with a different value is undefined behavior.
774 for (auto operand : operands) {
775 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
776 if (ShapedType::isStatic(size) && size > 1)
777 return {rewriter.getIndexAttr(size), operand};
778 }
779
780 // Filter operands with dynamic dimension
781 auto operandsWithDynamicDim =
782 llvm::filter_to_vector(operands, [&](Value operand) {
783 return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
784 });
785
786 // If no operand has a dynamic dimension, it means all sizes were 1
787 if (operandsWithDynamicDim.empty())
788 return {rewriter.getIndexAttr(1), operands.front()};
789
790 // Emit code that computes the runtime size for this dimension. If there is
791 // only one operand with a dynamic dimension, it is considered the master
792 // operand that determines the runtime size of the output dimension.
793 auto targetSize =
794 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
795 if (operandsWithDynamicDim.size() == 1)
796 return {targetSize, operandsWithDynamicDim[0]};
797
798 // Calculate maximum size among all dynamic dimensions
799 for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
800 auto nextSize =
801 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
802 targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
803 }
804 return {targetSize, nullptr};
805}
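// A sketch of the broadcast size rule implemented above, restricted to static
// sizes (illustrative only): any size greater than 1 fixes the output size,
// and if every operand has size 1 the output size is 1. Dynamic sizes instead
// become a runtime maximum, as in the code above.
[[maybe_unused]] static int64_t sketchTargetSize(const int64_t *sizes,
                                                 int numOperands) {
  for (int i = 0; i < numOperands; ++i)
    if (sizes[i] > 1)
      return sizes[i]; // mismatching sizes > 1 would be undefined behavior
  return 1;
}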
806
807// Compute the runtime output size for all dimensions. This function returns
808// a pair {targetShape, masterOperands}.
809static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
810computeTargetShape(PatternRewriter &rewriter, Location loc,
811 IndexPool &indexPool, ValueRange operands) {
812 assert(!operands.empty());
813 auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
814 SmallVector<OpFoldResult> targetShape;
815 SmallVector<Value> masterOperands;
816 for (auto dim : llvm::seq<int64_t>(0, rank)) {
817 auto [targetSize, masterOperand] =
818 computeTargetSize(rewriter, loc, indexPool, operands, dim);
819 targetShape.push_back(targetSize);
820 masterOperands.push_back(masterOperand);
821 }
822 return {targetShape, masterOperands};
823}
824
825static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
826 IndexPool &indexPool, Value operand,
827 int64_t dim, OpFoldResult targetSize,
828 Value masterOperand) {
829 // Nothing to do if this is a static dimension
830 auto rankedTensorType = cast<RankedTensorType>(operand.getType());
831 if (!rankedTensorType.isDynamicDim(dim))
832 return operand;
833
834 // If the target size for this dimension was directly inferred by only taking
835 // this operand into account, there is no need to broadcast. This is an
836 // optimization that will prevent redundant control flow, and constitutes the
837 // main motivation for tracking "master operands".
838 if (operand == masterOperand)
839 return operand;
840
841 // Affine maps for 'linalg.generic' op
842 auto rank = rankedTensorType.getRank();
843 SmallVector<AffineExpr> affineExprs;
844 for (auto index : llvm::seq<int64_t>(0, rank)) {
845 auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
846 : rewriter.getAffineDimExpr(index);
847 affineExprs.push_back(affineExpr);
848 }
849 auto broadcastAffineMap =
850 AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
851 auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
852 SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};
853
854 // Check if broadcast is necessary
855 auto one = createIndex(rewriter, loc, indexPool, 1);
856 auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
857 auto broadcastNecessary = arith::CmpIOp::create(
858 rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
859
860 // Emit 'then' region of 'scf.if'
861 auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
862 // It is not safe to cache constants across regions.
863 // New constants could potentially violate dominance requirements.
864 IndexPool localPool;
865
866 // Emit 'tensor.empty' op
867 SmallVector<OpFoldResult> outputTensorShape;
868 for (auto index : llvm::seq<int64_t>(0, rank)) {
869 auto size = index == dim ? targetSize
870 : getOrFoldTensorDim(rewriter, loc, localPool,
871 operand, index);
872 outputTensorShape.push_back(size);
873 }
874 Value outputTensor = tensor::EmptyOp::create(
875 opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
876
877 // Emit 'linalg.generic' op
878 auto resultTensor =
879 linalg::GenericOp::create(
880 opBuilder, loc, outputTensor.getType(), operand, outputTensor,
881 affineMaps, getNParallelLoopsAttrs(rank),
882 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
883 // Emit 'linalg.yield' op
884 linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
885 })
886 .getResult(0);
887
888 // Cast to original operand type if necessary
889 auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
890 loc, operand.getType(), resultTensor);
891
892 // Emit 'scf.yield' op
893 scf::YieldOp::create(opBuilder, loc, castResultTensor);
894 };
895
896 // Emit 'else' region of 'scf.if'
897 auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
898 scf::YieldOp::create(opBuilder, loc, operand);
899 };
900
901 // Emit 'scf.if' op
902 auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
903 emitThenRegion, emitElseRegion);
904 return ifOp.getResult(0);
905}
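// A 1-D sketch of what the scf.if emitted above computes at runtime, assuming
// plain arrays (illustrative only): when the operand's dynamic size turns out
// to be 1 it is broadcast to the target size, otherwise it passes through.
[[maybe_unused]] static void sketchBroadcastDim(const float *in, int64_t inSize,
                                                float *out, int64_t outSize) {
  if (inSize == 1) {
    for (int64_t i = 0; i < outSize; ++i)
      out[i] = in[0]; // materialized broadcast copy
  } else {
    for (int64_t i = 0; i < outSize; ++i)
      out[i] = in[i]; // sizes already match; no broadcast needed
  }
}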
906
907static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
908 IndexPool &indexPool, Value operand,
909 ArrayRef<OpFoldResult> targetShape,
910 ArrayRef<Value> masterOperands) {
911 int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
912 assert((int64_t)targetShape.size() == rank);
913 assert((int64_t)masterOperands.size() == rank);
914 for (auto index : llvm::seq<int64_t>(0, rank))
915 operand =
916 broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
917 targetShape[index], masterOperands[index]);
918 return operand;
919}
920
921static SmallVector<Value>
922broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
923 IndexPool &indexPool, ValueRange operands,
924 ArrayRef<OpFoldResult> targetShape,
925 ArrayRef<Value> masterOperands) {
926 // No need to broadcast for unary operations
927 if (operands.size() == 1)
928 return operands;
929
930 // No need to broadcast for static shape
931 bool hasDynamic = false;
932 for (auto op : operands) {
933 const auto tType = dyn_cast<RankedTensorType>(op.getType());
934 if (tType && !tType.hasStaticShape()) {
935 hasDynamic = true;
936 break;
937 }
938 }
939 if (!hasDynamic)
940 return operands;
941
942 // Broadcast dynamic dimensions operand by operand
943 return llvm::map_to_vector(operands, [&](Value operand) {
944 return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
945 targetShape, masterOperands);
946 });
947}
948
949static LogicalResult
950emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
951 Operation *operation, ValueRange operands,
952 ArrayRef<OpFoldResult> targetShape,
953 const TypeConverter &converter) {
954 // Generate output tensor
955 auto resultType = cast_or_null<RankedTensorType>(
956 converter.convertType(operation->getResultTypes().front()));
957 if (!resultType) {
958 return rewriter.notifyMatchFailure(operation, "failed to convert type");
959 }
960 Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
961 resultType.getElementType());
962
963 // Create affine maps. Input affine maps broadcast static dimensions of size
964 // 1. The output affine map is an identity map.
965 //
966 auto rank = resultType.getRank();
967 auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
968 auto shape = cast<ShapedType>(operand.getType()).getShape();
969 SmallVector<AffineExpr> affineExprs;
970 for (auto it : llvm::enumerate(shape)) {
971 // Prefer producing identity maps whenever possible (i.e. no broadcasting
972 // needed) because some transforms (like reshape folding)
973 // do not support affine constant exprs.
974 bool requiresBroadcast =
975 (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
976 auto affineExpr = requiresBroadcast
977 ? rewriter.getAffineConstantExpr(0)
978 : rewriter.getAffineDimExpr(it.index());
979 affineExprs.push_back(affineExpr);
980 }
981 return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
982 });
983 affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
984
985 // Emit 'linalg.generic' op
986 bool encounteredError = false;
987 auto linalgOp = linalg::GenericOp::create(
988 rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
989 getNParallelLoopsAttrs(rank),
990 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
991 Value opResult = createLinalgBodyCalculationForElementwiseOp(
992 operation, blockArgs.take_front(operation->getNumOperands()),
993 {resultType.getElementType()}, rewriter);
994 if (!opResult) {
995 encounteredError = true;
996 return;
997 }
998 linalg::YieldOp::create(opBuilder, loc, opResult);
999 });
1000 if (encounteredError)
1001 return rewriter.notifyMatchFailure(
1002 operation, "unable to create linalg.generic body for elementwise op");
1003
1004 // Cast 'linalg.generic' result into original result type if needed
1005 auto castResult = rewriter.createOrFold<tensor::CastOp>(
1006 loc, resultType, linalgOp->getResult(0));
1007 rewriter.replaceOp(operation, castResult);
1008 return success();
1009}
1010
1011static ValueRange getBroadcastableOperands(Operation *operation,
1012 ValueRange operands) {
1013 // Shift cannot broadcast
1014 if (isa<tosa::MulOp>(operation)) {
1015 DenseElementsAttr shiftElems;
1016 // Shift cannot broadcast when it is constant
1017 if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
1018 return operands.take_front(2);
1019 else
1020 return operands.take_front(3);
1021 }
1022 if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1023 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
1024 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
1025 if (failed(maybeOutZp) && failed(maybeInZp))
1026 return operands;
1027 // Input1_zp and output_zp cannot broadcast when they are constants.
1028 return operands.take_front(1);
1029 }
1030 return operands;
1031}
1032
1033static LogicalResult
1034elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
1035 ConversionPatternRewriter &rewriter,
1036 const TypeConverter &converter) {
1037
1038 // Collect op properties
1039 assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1040 assert(operation->getNumOperands() >= 1 &&
1041 "elementwise op expects at least 1 operand");
1042 if (!operandsAndResultsRanked(operation))
1043 return rewriter.notifyMatchFailure(operation,
1044 "Unranked tensors not supported");
1045
1046 // Lower operation
1047 IndexPool indexPool;
1048 auto loc = operation->getLoc();
1049 auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1050 auto [targetShape, masterOperands] =
1051 computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
1052 auto broadcastOperands =
1053 broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
1054 targetShape, masterOperands);
1055 return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
1056 targetShape, converter);
1057}
1058
1059// Returns the constant initial value for a given reduction operation. The
1060// attribute type varies depending on the element type required.
1061static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1062 PatternRewriter &rewriter) {
1063 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1064 return rewriter.getFloatAttr(elementTy, 0.0);
1065
1066 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1067 return rewriter.getIntegerAttr(elementTy, 0);
1068
1069 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1070 return rewriter.getFloatAttr(elementTy, 1.0);
1071
1072 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1073 return rewriter.getIntegerAttr(elementTy, 1);
1074
1075 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1076 return rewriter.getFloatAttr(
1077 elementTy, APFloat::getLargest(
1078 cast<FloatType>(elementTy).getFloatSemantics(), false));
1079
1080 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1081 return rewriter.getIntegerAttr(
1082 elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
1083
1084 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1085 return rewriter.getFloatAttr(
1086 elementTy, APFloat::getLargest(
1087 cast<FloatType>(elementTy).getFloatSemantics(), true));
1088
1089 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1090 return rewriter.getIntegerAttr(
1091 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1092
1093 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1094 return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
1095
1096 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1097 return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
1098
1099 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1100 return rewriter.getFloatAttr(
1101 elementTy, APFloat::getLargest(
1102 cast<FloatType>(elementTy).getFloatSemantics(), true));
1103
1104 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1105 return rewriter.getIntegerAttr(
1106 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
1107
1108 return {};
1109}
1110
1111// Creates the body calculation for a reduction. The operations vary depending
1112// on the input type.
1113static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1114 ValueRange args,
1115 Type elementTy,
1116 PatternRewriter &rewriter) {
1117 Location loc = op->getLoc();
1118 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1119 return arith::AddFOp::create(rewriter, loc, args);
1120 }
1121
1122 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1123 return arith::AddIOp::create(rewriter, loc, args);
1124 }
1125
1126 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1127 return arith::MulFOp::create(rewriter, loc, args);
1128 }
1129
1130 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1131 return arith::MulIOp::create(rewriter, loc, args);
1132 }
1133
1134 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1135 return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
1136 }
1137
1138 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1139 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
1140 }
1141
1142 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1143 return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
1144 }
1145
1146 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1147 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
1148 }
1149
1150 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1151 return arith::AndIOp::create(rewriter, loc, args);
1152
1153 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1154 return arith::OrIOp::create(rewriter, loc, args);
1155
1156 return {};
1157}
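// A scalar sketch tying the two helpers above together for a float
// reduce_sum, assuming f32 (illustrative only): the accumulator is seeded
// with the initial value returned for the op (0.0 for sum) and then folded
// with the op's body calculation (addf).
[[maybe_unused]] static float sketchReduceSum(const float *data, int64_t n) {
  float acc = 0.0f; // initial value for reduce_sum
  for (int64_t i = 0; i < n; ++i)
    acc = acc + data[i]; // body calculation: addf
  return acc;
}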
1158
1159// Performs the match and rewrite for reduction operations. This includes
1160// declaring a correctly sized initial value, and the linalg.generic operation
1161// that reduces across the specified axis.
1162template <typename OpTy>
1163static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
1164 PatternRewriter &rewriter) {
1165 auto loc = op->getLoc();
1166 auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1167 auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1168 if (!inputTy || !resultTy)
1169 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
1170
1171 auto elementTy = resultTy.getElementType();
1172 Value input = op->getOperand(0);
1173
1174 // Figure out the accType if needed
1175 bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1176 isa<FloatType>(elementTy) &&
1177 cast<FloatType>(elementTy).isBF16();
1178 Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
1179
1180 SmallVector<int64_t> reduceShape;
1181 SmallVector<Value> dynDims;
1182 for (unsigned i = 0; i < inputTy.getRank(); i++) {
1183 if (axis != i) {
1184 reduceShape.push_back(inputTy.getDimSize(i));
1185 if (inputTy.isDynamicDim(i))
1186 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1187 }
1188 }
1189
1190 SmallVector<Value> inputs, outputs;
1191 inputs.push_back(input);
1192
1193 // First fill the output buffer with the init value.
1194 auto emptyTensor =
1195 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1196 .getResult();
1197
1198 auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
1199 if (!fillValueAttr)
1200 return rewriter.notifyMatchFailure(
1201 op, "No initial value found for reduction operation");
1202
1203 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
1204 auto filledTensor =
1205 linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
1206 ValueRange{emptyTensor})
1207 .result();
1208 outputs.push_back(filledTensor);
1209
1210 bool isNanIgnoreMode = false;
1211 if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1212 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1213 // NaN propagation has no meaning for non floating point types.
1214 if (isa<FloatType>(elementTy) &&
1215 op.getNanMode() == NanPropagationMode::IGNORE) {
1216 isNanIgnoreMode = true;
1217 // Because the TOSA spec requires the result be NaN iff all elements in
1218 // the reduction are NaN we can't simply perform a compare and select.
1219 // Additionally we have to keep track of whether we've seen any non-NaN
1220 // values and then do a final select based on this predicate.
1221 auto trueAttr = rewriter.getBoolAttr(true);
1222 auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
1223 auto emptyBoolTensor =
1224 tensor::EmptyOp::create(rewriter, loc, reduceShape,
1225 trueValue.getType(), dynDims)
1226 .getResult();
1227 auto allResultsNaNTensor =
1228 linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
1229 ValueRange{emptyBoolTensor})
1230 .result();
1231 // Note that because the linalg::ReduceOp has two variadic arguments
1232 // (inputs and outputs) and has the SameVariadicOperandSize trait, we
1233 // need to have the same number of inputs and outputs.
1234 //
1235 // The second input isn't actually used anywhere since the value used to
1236 // update the NaN flag is calculated inside the body of the reduction and
1237 // then used to update an out value.
1238 // In order to satisfy type constraints we just pass another copy of the
1239 // input here.
1240 inputs.push_back(input);
1241 outputs.push_back(allResultsNaNTensor);
1242 }
1243 }
1244
1245 bool didEncounterError = false;
1246 linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
1247 rewriter, loc, inputs, outputs, axis,
1248 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
1249 std::array<Value, 2> binaryArgs{
1250 blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1251
1252 // If reduction type differs then extend (applicable to reduce_sum)
1253 if (binaryArgs[0].getType() != accTy)
1254 binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1255 binaryArgs[0]);
1256
1257 auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
1258 accTy, rewriter);
1259 if (result)
1260 didEncounterError = true;
1261
1262 SmallVector<Value> resultsToYield;
1263 if (isNanIgnoreMode) {
1264 auto inputValue = blockArgs[0];
1265 auto initialValue = blockArgs[2];
1266 auto oldAllResultsNanFlagValue = blockArgs[3];
1267
1268 // Unordered comparison of NaN against itself will always return true.
1269 Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
1270 arith::CmpFPredicate::UNO,
1271 inputValue, inputValue);
1272 // If we've encountered a NaN, take the non-NaN value.
1273 auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
1274 isNaN, initialValue, result);
1275 // Update the flag which keeps track of whether we have seen a non-NaN
1276 // value.
1277 auto newAllResultsNanFlagValue = arith::AndIOp::create(
1278 nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1279 resultsToYield.push_back(selectOp);
1280 resultsToYield.push_back(newAllResultsNanFlagValue);
1281 } else {
1282 resultsToYield.push_back(result);
1283 }
1284 linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
1285 });
1286
1287 if (!didEncounterError)
1288 return rewriter.notifyMatchFailure(
1289 op, "unable to create linalg.generic body for reduce op");
1290
1291 if (isNanIgnoreMode) {
1292 // Materialize a check to see whether we encountered any non-NaN values; if
1293 // we didn't, we need to select a tensor of NaNs since the result will just
1294 // be the initial identity value propagated through all the compares and
1295 // selects inside the reduction.
1296
1297 // Create a tensor full of NaNs.
1298 auto nanValueAttr = rewriter.getFloatAttr(
1299 accTy,
1300 APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
1301 auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
1302 auto emptyNanTensor =
1303 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1304 .getResult();
1305 auto nanFilledTensor =
1306 linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
1307 ValueRange{emptyNanTensor})
1308 .result();
1309
1310 // Create an empty tensor; no need to fill it since it will be
1311 // overwritten by the select.
1312 auto finalEmptyTensor =
1313 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1314 .getResult();
1315
1316 // Do a selection between the tensors akin to:
1317 // result = NaN if "all results NaN" else result.
1318 SmallVector<Value> ins, outs;
1319 ins.push_back(linalgOp->getOpResult(1));
1320 ins.push_back(nanFilledTensor);
1321 ins.push_back(linalgOp->getResult(0));
1322 outs.push_back(finalEmptyTensor);
1323 auto linalgSelect =
1324 linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
1325 linalgOp = linalgSelect;
1326 }
1327
1328 // Truncate back to resultTy if needed
1329 Value reducedRes = linalgOp->getResult(0);
1330 if (widenAccTy) {
1331 auto resEmptyOp =
1332 tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1333 .getResult();
1334
1335 const unsigned reducedRank =
1336 cast<ShapedType>(reducedRes.getType()).getRank();
1337 auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
1338 reducedRes =
1339 linalg::GenericOp::create(
1340 rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
1341 ValueRange{resEmptyOp},
1342 ArrayRef<AffineMap>{identityMap, identityMap},
1343 getNParallelLoopsAttrs(reducedRank),
1344 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1345 Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1346 elementTy, args[0]);
1347 linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
1348 })
1349 .getResults()[0];
1350 }
1351
1352 SmallVector<ReassociationExprs, 4> reassociationMap;
1353 uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
1354 reassociationMap.resize(expandInputRank);
1355
1356 for (uint64_t i = 0; i < expandInputRank; i++) {
1357 int32_t dimToPush = i > axis ? i + 1 : i;
1358 reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
1359 }
1360
1361 if (expandInputRank != 0) {
1362 int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1363 reassociationMap[expandedDim].push_back(
1364 rewriter.getAffineDimExpr(expandedDim + 1));
1365 }
1366
1367 // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
1368 // since here we know which dimension to expand, and `tosa::ReshapeOp` would
1369 // not have access to such information. This matters when handling dynamically
1370 // sized tensors.
1371 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
1372 reassociationMap);
1373 return success();
1374}
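// A scalar sketch of the "IGNORE" NaN handling above for a float reduce_max,
// assuming f32 and that <limits> is available (the surrounding code already
// uses std::numeric_limits). Illustrative only: NaN inputs are skipped, but a
// flag tracks whether every element was NaN so the final result can still be
// NaN in that case, matching the linalg.select emitted after the reduction.
[[maybe_unused]] static float sketchReduceMaxIgnoreNan(const float *data,
                                                       int64_t n) {
  float acc = -std::numeric_limits<float>::max(); // lowest finite f32
  bool allNaN = true;
  for (int64_t i = 0; i < n; ++i) {
    bool isNaN = data[i] != data[i]; // unordered self-compare
    allNaN = allNaN && isNaN;
    if (!isNaN && data[i] > acc)
      acc = data[i];
  }
  return allNaN ? std::numeric_limits<float>::quiet_NaN() : acc;
}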
1375
1376namespace {
1377
1378template <typename SrcOp>
1379class PointwiseConverter : public OpConversionPattern<SrcOp> {
1380public:
1381 using OpConversionPattern<SrcOp>::OpConversionPattern;
1382 using typename OpConversionPattern<SrcOp>::OpAdaptor;
1383
1384 LogicalResult
1385 matchAndRewrite(SrcOp op, OpAdaptor operands,
1386 ConversionPatternRewriter &rewriter) const final {
1387 return elementwiseMatchAndRewriteHelper(
1388 op, operands.getOperands(), rewriter, *this->getTypeConverter());
1389 }
1390};
1391
1392// Collapse tensor<1xiN> into tensor<iN>
1393// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1394static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
1395 Location loc) {
1396 SmallVector<ReassociationExprs> reassociation;
1397 // Create the collapsed type
1398 auto inputType = cast<RankedTensorType>(input.getType());
1399 auto elemType = inputType.getElementType();
1400 auto collapsedType = RankedTensorType::get({}, elemType);
1401 // Emit the collapse op
1402 return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
1403 reassociation);
1404}
1405
1406static llvm::SmallVector<int8_t>
1407convertToI8(const llvm::SmallVector<int32_t> &input) {
1408 llvm::SmallVector<int8_t> output;
1409 output.reserve(input.size());
1410
1411 for (auto v : llvm::map_range(
1412 input, [](int32_t val) { return static_cast<int8_t>(val); })) {
1413 output.push_back(v);
1414 }
1415 return output;
1416}
1417
1418// The shift or multiplier may be either constant or non-constant, depending on
1419// whether dynamic extension is enabled.
1420// - If the shift or multiplier is non-constant, add it as an input to
1421// linalg::GenericOp by:
1422// 1. Pushing it into 'genericInputs'.
1423// 2. Appending a corresponding affine map to 'indexingMaps'.
1424// - If the shift or multiplier is constant, set 'constant' instead.
1425static void setupLinalgGenericOpInputAndIndexingMap(
1426 PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
1427 SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1428 bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
1429 bool isShift = false) {
1430
1431 auto loc = op.getLoc();
1432 auto inputTy = cast<ShapedType>(op.getInput().getType());
1433 unsigned rank = inputTy.getRank();
1434 SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
1435
1436 if (isConstant) {
1437 // If we are rescaling per-channel then we need to store the
1438 // values in a buffer.
1439 if (values.size() == 1) {
1440 IntegerAttr intAttr = isShift
1441 ? rewriter.getI8IntegerAttr(values.front())
1442 : rewriter.getI32IntegerAttr(values.front());
1443 constant = arith::ConstantOp::create(rewriter, loc, intAttr);
1444 } else {
1445 auto elementType =
1446 isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
1447 auto tensorType = RankedTensorType::get(
1448 {static_cast<int64_t>(values.size())}, elementType);
1449 DenseIntElementsAttr EltAttr;
1450 if (isShift)
1451 EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
1452 else
1453 EltAttr = DenseIntElementsAttr::get(tensorType, values);
1454 genericInputs.push_back(
1455 arith::ConstantOp::create(rewriter, loc, EltAttr));
1456 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1457 /*symbolCount=*/0, exprs,
1458 rewriter.getContext()));
1459 }
1460 } else {
1461 // If we are not rescaling per-channel then we need to collapse 1xN to N
1462 // and push broadcastMap.
1463 auto operand = isShift ? op.getShift() : op.getMultiplier();
1464 auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
1465 if (tensorType && tensorType.hasStaticShape() &&
1466 tensorType.getShape()[0] == 1) {
1467 // broadcastMap = affine_map<(d0, d1) -> ()>
1468 // It would affect as broadcast for scalar values in linalg::GenericOp.
1469 AffineMap broadcastMap =
1470 AffineMap::get(rank, 0, {}, rewriter.getContext());
1471 genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
1472 indexingMaps.push_back(broadcastMap);
1473 } else {
1474 genericInputs.push_back(operand);
1475 indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1476 /*symbolCount=*/0, exprs,
1477 rewriter.getContext()));
1478 }
1479 }
1480 arg = indexingMaps.size() - 1;
1481}
1482
1483// Return the extended Zp to be used in subsequent arithmetic operations.
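// For example (illustrative only): a constant i8 input zero point of -128 is
// materialized as "arith.constant -128 : i32", while a non-constant unsigned
// zero point arrives as a block argument that is first cast to a signless
// integer with unrealized_conversion_cast and then widened to i32 with
// arith.extui (arith.extsi for the signed case).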
1484static Value getExtendZp(OpBuilder &builder, Type valueTy,
1485 FailureOr<int64_t> maybeZp, Location loc,
1486 ValueRange blockArgs, int64_t zpArg,
1487 bool isOutputZp = false) {
1488 Value result;
1489 const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1490 const uint32_t attrBitwidth =
1491 isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1492 auto extendType = builder.getIntegerType(attrBitwidth);
1493 // The Zp value can be either constant or non-constant, depending on
1494 // whether dynamic extension is enabled.
1495 // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1496 // be passed as an input to linalg::GenericOp.
1497 if (failed(maybeZp)) {
1498 result = blockArgs[zpArg];
1499 auto zpTy = result.getType();
1500 if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1501 // For ExtUIOp, the input must be signless.
1502 // UnrealizedConversionCastOp will cast the input to signless type.
1503 if (zpTy.isUnsignedInteger()) {
1504 result =
1505 UnrealizedConversionCastOp::create(
1506 builder, loc,
1507 builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
1508 .getResult(0);
1509 }
1510 if (zpTy.isUnsignedInteger()) {
1511 return arith::ExtUIOp::create(builder, loc, extendType, result);
1512 } else {
1513 return arith::ExtSIOp::create(builder, loc, extendType, result);
1514 }
1515 }
1516 } else {
1517 return arith::ConstantOp::create(builder, loc,
1518 IntegerAttr::get(extendType, *maybeZp));
1519 }
1520 return result;
1521}
1522
1523class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1524public:
1525 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1526
1527 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1528 PatternRewriter &rewriter) const final {
1529 auto loc = op.getLoc();
1530 auto input = op.getInput();
1531 auto inputTy = cast<ShapedType>(op.getInput().getType());
1532 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1533 unsigned rank = inputTy.getRank();
1534
1535 // This is an illegal configuration; terminate and log an error.
1536 if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
1537 return rewriter.notifyMatchFailure(
1538 op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1539 "currently supported");
1540 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
1541 return rewriter.notifyMatchFailure(
1542 op, "tosa.rescale requires scale32 for double_round to be true");
1543
1544 if (!isa<IntegerType>(inputTy.getElementType()))
1545 return rewriter.notifyMatchFailure(op, "only support integer type");
1546
1547 SmallVector<Value> dynDims;
1548 for (int i = 0; i < outputTy.getRank(); i++) {
1549 if (outputTy.isDynamicDim(i)) {
1550 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1551 }
1552 }
1553
1554 DenseElementsAttr shiftElems;
1555 bool isShiftConstant = false;
1556 if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1557 isShiftConstant = true;
1558
1559 DenseElementsAttr multiplierElems;
1560 bool isMultiplierConstant = false;
1561 if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1562 isMultiplierConstant = true;
1563
1564 llvm::SmallVector<int32_t> shiftValues;
1565 llvm::SmallVector<int32_t> multiplierValues;
1566 bool doubleRound;
1567
1568 if (isMultiplierConstant && isShiftConstant) {
1569 // explicit cast is required here
1570 shiftValues = llvm::map_to_vector(
1571 shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1572 return static_cast<int32_t>(attr.getInt());
1573 });
1574 multiplierValues =
1575 llvm::map_to_vector(multiplierElems.getValues<IntegerAttr>(),
1576 [](IntegerAttr attr) -> int32_t {
1577 return static_cast<int32_t>(attr.getInt());
1578 });
1579
1580 // If we shift by more than the bitwidth, this just sets to 0.
1581 for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1582 if (shiftValues[i] > 63) {
1583 shiftValues[i] = 0;
1584 multiplierValues[i] = 0;
1585 }
1586 }
1587 // Double rounding only occurs if the shift is greater than 31; check
1588 // whether this is ever true.
1589 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1590 llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1591 } else
1592 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
1593
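    // For instance, if every constant shift value is <= 31, a DOUBLE_ROUND
    // rescale degenerates to SINGLE_ROUND here, since double rounding only
    // changes the result for shifts larger than 31.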
1594 RoundingMode roundingMode =
1595 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
1596
1597 SmallVector<AffineMap> indexingMaps = {
1598 rewriter.getMultiDimIdentityMap(rank)};
1599 SmallVector<Value, 4> genericInputs = {input};
1600
1601 // If we are rescaling per-channel then we need to store the multiplier
1602 // values in a buffer.
1603 Value multiplierConstant;
1604 int64_t multiplierArg = 0;
1605 setupLinalgGenericOpInputAndIndexingMap(
1606 rewriter, multiplierValues, genericInputs, indexingMaps,
1607 isMultiplierConstant, op, multiplierConstant, multiplierArg);
1608
1609 // If we are rescaling per-channel then we need to store the shift
1610 // values in a buffer.
1611 Value shiftConstant;
1612 int64_t shiftArg = 0;
1613 setupLinalgGenericOpInputAndIndexingMap(
1614 rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1615 shiftConstant, shiftArg, true);
1616
1617 // broadcastMap = affine_map<(d0, d1) -> ()>
1618 // It acts as a broadcast for scalar values in linalg::GenericOp.
1619 AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1620 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1621 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1622 // The inputZp and outputZp may be either constant or non-constant,
1623 // depending on whether dynamic extension is enabled.
1624 // - If the zp's are non-constant, add them as inputs to
1625 // linalg::GenericOp by:
1626 // 1. Pushing them into 'genericInputs'.
1627 // 2. Appending a corresponding affine map to 'indexingMaps'.
1628 // - If the zp's are constant, they would be generated as arith.constant.
1629 int64_t iZpArg = 0;
1630 if (failed(maybeIZp)) {
1631 genericInputs.push_back(
1632 collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1633 indexingMaps.push_back(broadcastMap);
1634 iZpArg = indexingMaps.size() - 1;
1635 }
1636 int64_t oZpArg = 0;
1637 if (failed(maybeOZp)) {
1638 genericInputs.push_back(
1639 collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1640 indexingMaps.push_back(broadcastMap);
1641 oZpArg = indexingMaps.size() - 1;
1642 }
1643
1644 // Indexing maps for output values.
1645 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1646
1647 // Construct the indexing maps needed for linalg.generic ops.
1648 Value emptyTensor = tensor::EmptyOp::create(
1649 rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
1650 ArrayRef<Value>({dynDims}));
1651
1652 auto linalgOp = linalg::GenericOp::create(
1653 rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
1654 indexingMaps, getNParallelLoopsAttrs(rank),
1655 [&](OpBuilder &nestedBuilder, Location nestedLoc,
1656 ValueRange blockArgs) {
1657 Value value = blockArgs[0];
1658 Type valueTy = value.getType();
1659
1660 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1661 auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1662 nestedLoc, blockArgs, iZpArg);
1663
1664 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1665 auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1666 nestedLoc, blockArgs, oZpArg, true);
1667
1668 IntegerType outIntType =
1669 cast<IntegerType>(blockArgs.back().getType());
1670 unsigned outBitWidth = outIntType.getWidth();
1671 assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1672
1673 Value multiplier = multiplierConstant ? multiplierConstant
1674 : blockArgs[multiplierArg];
1675 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1676
1677 if (valueTy.isUnsignedInteger()) {
1678 value = UnrealizedConversionCastOp::create(
1679 nestedBuilder, nestedLoc,
1680 nestedBuilder.getIntegerType(
1681 valueTy.getIntOrFloatBitWidth()),
1682 value)
1683 .getResult(0);
1684 }
1685 if (valueTy.getIntOrFloatBitWidth() < 32) {
1686 if (op.getInputUnsigned()) {
1687 value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
1688 nestedBuilder.getI32Type(), value);
1689 } else {
1690 value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
1691 nestedBuilder.getI32Type(), value);
1692 }
1693 }
1694
1695 value =
1696 arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
1697
1698 value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
1699 nestedBuilder.getI32Type(), value,
1700 multiplier, shift, roundingMode);
1701
1702 // Move to the new zero-point.
1703 value =
1704 arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
1705
1706 // Saturate to the output size.
1707 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1708 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1709
1710 // Unsigned integers have a different output range.
1711 if (op.getOutputUnsigned()) {
1712 intMin = 0;
1713 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1714 }
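          // For example, an i8 output is clamped to [-128, 127], or to
          // [0, 255] when the output is unsigned.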
1715
1716 auto intMinVal = arith::ConstantOp::create(
1717 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
1718 auto intMaxVal = arith::ConstantOp::create(
1719 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
1720
1721 value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
1722 nestedBuilder, /*isUnsigned=*/false);
1723
1724 if (outIntType.getWidth() < 32) {
1725 value = arith::TruncIOp::create(
1726 nestedBuilder, nestedLoc,
1727 rewriter.getIntegerType(outIntType.getWidth()), value);
1728 }
1729
1730 if (outIntType.isUnsignedInteger()) {
1731 value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
1732 outIntType, value)
1733 .getResult(0);
1734 }
1735 linalg::YieldOp::create(nestedBuilder, loc, value);
1736 });
1737
1738 rewriter.replaceOp(op, linalgOp->getResults());
1739 return success();
1740 }
1741};
1742
1743// Handle the resize case where the input is a 1x1 image. This case
1744// can entirely avoid the extract operations, which are much more
1745// difficult to optimize away.
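// For instance (an illustrative case): a BILINEAR resize from
// tensor<1x1x1x8xi8> to tensor<1x1x1x8xi32> reduces to a sign extension of
// each element followed by multiplications with the constant y and x scale
// numerators (when those are non-zero).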
1746class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1747public:
1748 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1749
1750 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1751 PatternRewriter &rewriter) const final {
1752 Location loc = op.getLoc();
1753 ImplicitLocOpBuilder builder(loc, rewriter);
1754 auto input = op.getInput();
1755 auto inputTy = cast<RankedTensorType>(input.getType());
1756 auto resultTy = cast<RankedTensorType>(op.getType());
1757 const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
1758
1759 auto inputH = inputTy.getDimSize(1);
1760 auto inputW = inputTy.getDimSize(2);
1761 auto outputH = resultTy.getDimSize(1);
1762 auto outputW = resultTy.getDimSize(2);
1763
1764 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1765 return rewriter.notifyMatchFailure(
1766 op, "tosa.resize is not a pure 1x1->1x1 image operation");
1767
1768 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1769 op.getMode() != ResizeMode::BILINEAR)
1770 return rewriter.notifyMatchFailure(
1771 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1772
1773 if (inputTy == resultTy) {
1774 rewriter.replaceOp(op, input);
1775 return success();
1776 }
1777
1778 SmallVector<int64_t> scale;
1779 if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
1780 return failure();
1781 }
1782
1783 // Collapse the unit width and height away.
1784 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1785 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1786 reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1787 reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1788 reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1789
1790 auto collapseTy =
1791 RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1792 inputTy.getElementType());
1793 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
1794 reassociationMap);
1795
1796 // Get any dynamic shapes that appear in the input format.
1797 llvm::SmallVector<Value> outputDynSize;
1798 if (inputTy.isDynamicDim(0))
1799 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1800 if (inputTy.isDynamicDim(3))
1801 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1802
1803 // Generate the elementwise operation for casting and scaling the input value.
1804 auto genericTy = collapseTy.clone(resultTy.getElementType());
1805 Value empty =
1806 tensor::EmptyOp::create(builder, genericTy.getShape(),
1807 resultTy.getElementType(), outputDynSize);
1808 auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1809 SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1810 utils::IteratorType::parallel);
1811
1812 auto generic = linalg::GenericOp::create(
1813 builder, genericTy, ValueRange{collapse}, ValueRange{empty},
1814 ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1815 [=](OpBuilder &b, Location loc, ValueRange args) {
1816 Value value = args[0];
1817 // This is the quantized case.
1818 if (inputTy.getElementType() != resultTy.getElementType()) {
1819 value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
1820 value);
1821
1822 if (isBilinear && scale[0] != 0) {
1823 Value scaleY = arith::ConstantOp::create(
1824 b, loc, b.getI32IntegerAttr(scale[0]));
1825 value = arith::MulIOp::create(b, loc, value, scaleY);
1826 }
1827
1828 if (isBilinear && scale[2] != 0) {
1829 Value scaleX = arith::ConstantOp::create(
1830 b, loc, b.getI32IntegerAttr(scale[2]));
1831 value = arith::MulIOp::create(b, loc, value, scaleX);
1832 }
1833 }
1834
1835 linalg::YieldOp::create(b, loc, value);
1836 });
1837
1838 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1839 op, resultTy, generic.getResults()[0], reassociationMap);
1840 return success();
1841 }
1842};
1843
1844// TOSA resize with width or height of 1 may be broadcasted to a wider
1845// dimension. This is done by materializing a new tosa.resize without
1846// the broadcasting behavior, and an explicit broadcast afterwards.
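// For example (shapes picked for illustration): resizing tensor<1x1x5x8xf32>
// to tensor<1x7x5x8xf32> becomes a tosa.resize producing tensor<1x1x5x8xf32>,
// a collapse of the unit height dim, and a linalg.generic that broadcasts the
// result up to the full output height of 7.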
1847class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1848public:
1849 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1850
1851 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1852 PatternRewriter &rewriter) const final {
1853 Location loc = op.getLoc();
1854 ImplicitLocOpBuilder builder(loc, rewriter);
1855 auto input = op.getInput();
1856 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1857 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1858
1859 if (!inputTy || !resultTy)
1860 return rewriter.notifyMatchFailure(op,
1861 "requires ranked input/output types");
1862
1863 auto batch = inputTy.getDimSize(0);
1864 auto channels = inputTy.getDimSize(3);
1865 auto inputH = inputTy.getDimSize(1);
1866 auto inputW = inputTy.getDimSize(2);
1867 auto outputH = resultTy.getDimSize(1);
1868 auto outputW = resultTy.getDimSize(2);
1869
1870 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1871 return rewriter.notifyMatchFailure(
1872 op, "tosa.resize has no broadcasting behavior");
1873
1874 // For any dimension that is broadcastable we generate a width of 1
1875 // on the output.
1876 llvm::SmallVector<int64_t> resizeShape;
1877 resizeShape.push_back(batch);
1878 resizeShape.push_back(inputH == 1 ? 1 : outputH);
1879 resizeShape.push_back(inputW == 1 ? 1 : outputW);
1880 resizeShape.push_back(channels);
1881
1882 auto resizeTy = resultTy.clone(resizeShape);
1883 auto resize =
1884 tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
1885 op.getOffset(), op.getBorder(), op.getMode());
1886
1887 // Collapse any unit result dims.
1888 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1889 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1890 reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1891 if (inputH != 1)
1892 reassociationMap.push_back({});
1893 reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1894 if (inputW != 1)
1895 reassociationMap.push_back({});
1896 reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1897
1898 llvm::SmallVector<int64_t> collapseShape = {batch};
1899 if (inputH != 1)
1900 collapseShape.push_back(outputH);
1901 if (inputW != 1)
1902 collapseShape.push_back(outputW);
1903 collapseShape.push_back(channels);
1904
1905 auto collapseTy = resultTy.clone(collapseShape);
1906 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
1907 resize, reassociationMap);
1908
1909 // Broadcast the collapsed shape to the output result.
1910 llvm::SmallVector<Value> outputDynSize;
1911 if (inputTy.isDynamicDim(0))
1912 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1913 if (inputTy.isDynamicDim(3))
1914 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1915
1916 SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1917 utils::IteratorType::parallel);
1918 Value empty = tensor::EmptyOp::create(
1919 builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1920
1921 SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
1922 if (inputH != 1)
1923 inputExprs.push_back(rewriter.getAffineDimExpr(1));
1924 if (inputW != 1)
1925 inputExprs.push_back(rewriter.getAffineDimExpr(2));
1926 inputExprs.push_back(rewriter.getAffineDimExpr(3));
1927
1928 auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1929 inputExprs, rewriter.getContext());
1930
1931 auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1932 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1933 op, resultTy, ValueRange{collapse}, ValueRange{empty},
1934 ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1935 [=](OpBuilder &b, Location loc, ValueRange args) {
1936 Value value = args[0];
1937 linalg::YieldOp::create(b, loc, value);
1938 });
1939
1940 return success();
1941 }
1942};
1943
1944class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1945public:
1946 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1947
1948 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1949 PatternRewriter &rewriter) const final {
1950 Location loc = op.getLoc();
1951 ImplicitLocOpBuilder b(loc, rewriter);
1952 auto input = op.getInput();
1953 auto inputTy = cast<ShapedType>(input.getType());
1954 auto resultTy = cast<ShapedType>(op.getType());
1955 auto resultETy = resultTy.getElementType();
1956
1957 bool floatingPointMode = isa<FloatType>(resultETy);
1958 auto floatTy = resultETy;
1959
1960 auto imageH = inputTy.getShape()[1];
1961 auto imageW = inputTy.getShape()[2];
1962
1963 auto dynamicDimsOr =
1964 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1965 if (!dynamicDimsOr.has_value())
1966 return rewriter.notifyMatchFailure(
1967 op, "unable to get dynamic dimensions of tosa.resize");
1968
1969 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1970 op.getMode() != ResizeMode::BILINEAR)
1971 return rewriter.notifyMatchFailure(
1972 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1973
1974 SmallVector<AffineMap, 2> affineMaps = {
1975 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1976 auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
1977 resultETy, *dynamicDimsOr);
1978 auto genericOp = linalg::GenericOp::create(
1979 b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1980 getNParallelLoopsAttrs(resultTy.getRank()));
1981 Value resize = genericOp.getResult(0);
1982
1983 {
1984 OpBuilder::InsertionGuard regionGuard(b);
1985 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1986 TypeRange({resultETy}), loc);
1987 Value batch = linalg::IndexOp::create(b, 0);
1988 Value y = linalg::IndexOp::create(b, 1);
1989 Value x = linalg::IndexOp::create(b, 2);
1990 Value channel = linalg::IndexOp::create(b, 3);
1991
1992 Value zeroI32 =
1993 arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
1994 Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
1995 Value hMax =
1996 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
1997 Value wMax =
1998 arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));
1999
2000 Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
2001 Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);
2002
2003 SmallVector<int64_t> scale, offset, border;
2004 if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
2005 !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
2006 !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
2007 return rewriter.notifyMatchFailure(
2008 op, "tosa.resize scale/offset/border should have compile time "
2009 "constant values.");
2010 }
2011
2012 Value yScaleN, yScaleD, xScaleN, xScaleD;
2013 yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
2014 yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
2015 xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
2016 xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));
2017
2018 Value yOffset, xOffset, yBorder, xBorder;
2019 yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
2020 xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
2021 yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
2022 xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
2023
2024 // Compute the ix and dx values for both the X and Y dimensions.
2025 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
2026 Value scaleN, Value scaleD, Value offset,
2027 int size, ImplicitLocOpBuilder &b) {
2028 if (size == 1) {
2029 index = zeroI32;
2030 delta = zeroFp;
2031 return;
2032 }
2033 // x = x * scale_d + offset;
2034 // ix = floor(x / scale_n)
2035 Value val = arith::MulIOp::create(b, in, scaleD);
2036 val = arith::AddIOp::create(b, val, offset);
2037 index = arith::FloorDivSIOp::create(b, val, scaleN);
2038
2039 // rx = x % scale_n
2040 // dx = rx / scale_n
2041 Value r = arith::RemSIOp::create(b, val, scaleN);
2042 Value rFp = arith::SIToFPOp::create(b, floatTy, r);
2043 Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
2044 delta = arith::DivFOp::create(b, rFp, scaleNfp);
2045 };
2046
2047 // Compute the ix and dx values for the X and Y dimensions - int case.
2048 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
2049 Value scaleN, Value scaleD, Value offset,
2050 int size, ImplicitLocOpBuilder &b) {
2051 if (size == 1) {
2052 index = zeroI32;
2053 delta = zeroI32;
2054 return;
2055 }
2056 // x = x * scale_d + offset;
2057 // ix = floor(x / scale_n)
2058 // dx = x - ix * scale_n;
2059 Value val = arith::MulIOp::create(b, in, scaleD);
2060 val = arith::AddIOp::create(b, val, offset);
2061 index = arith::DivSIOp::create(b, val, scaleN);
2062 delta = arith::MulIOp::create(b, index, scaleN);
2063 delta = arith::SubIOp::create(b, val, delta);
2064 };
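      // Worked example (numbers are illustrative): with scale_n = 4,
      // scale_d = 2, offset = 0 and an input coordinate of 3, val = 6, so
      // ix = 1 and dx = 2 in the integer path (0.5 in the floating point
      // path).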
2065
2066 Value ix, iy, dx, dy;
2067 if (floatingPointMode) {
2068 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2069 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2070 } else {
2071 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
2072 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
2073 }
2074
2075 if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
2076 auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2077
2078 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
2079 Value max, int size,
2080 ImplicitLocOpBuilder &b) -> Value {
2081 if (size == 1) {
2082 return arith::ConstantIndexOp::create(b, 0);
2083 }
2084
2085 Value pred;
2086 if (floatingPointMode) {
2087 auto h =
2088 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
2089 pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
2090 } else {
2091 Value dvalDouble = arith::ShLIOp::create(b, dval, one);
2092 pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
2093 dvalDouble, scale);
2094 }
2095
2096 auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
2097 val = arith::AddIOp::create(b, val, offset);
2098 val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
2099 return arith::IndexCastOp::create(b, b.getIndexType(), val);
2100 };
2101
2102 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
2103 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
2104
2105 Value result = tensor::ExtractOp::create(
2106 b, input, ValueRange{batch, iy, ix, channel});
2107
2108 linalg::YieldOp::create(b, result);
2109 } else {
2110 // The mode here must be BILINEAR.
2111 assert(op.getMode() == ResizeMode::BILINEAR);
2112
2113 auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
2114
2115 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
2116 Value max, ImplicitLocOpBuilder &b) {
2117 val0 = in;
2118 val1 = arith::AddIOp::create(b, val0, oneVal);
2119 val0 =
2120 clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
2121 val1 =
2122 clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
2123 val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
2124 val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
2125 };
2126
2127 // Linalg equivalent to the section below:
2128 // int16_t iy0 = apply_max(iy, 0);
2129 // int16_t iy1 = apply_min(iy + 1, IH - 1);
2130 // int16_t ix0 = apply_max(ix, 0);
2131 // int16_t ix1 = apply_min(ix + 1, IW - 1);
2132 Value x0, x1, y0, y1;
2133 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
2134 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
2135
2136 Value y0x0 = tensor::ExtractOp::create(
2137 b, input, ValueRange{batch, y0, x0, channel});
2138 Value y0x1 = tensor::ExtractOp::create(
2139 b, input, ValueRange{batch, y0, x1, channel});
2140 Value y1x0 = tensor::ExtractOp::create(
2141 b, input, ValueRange{batch, y1, x0, channel});
2142 Value y1x1 = tensor::ExtractOp::create(
2143 b, input, ValueRange{batch, y1, x1, channel});
2144
2145 if (floatingPointMode) {
2146 auto oneVal =
2147 arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
2148 auto interpolate = [&](Value val0, Value val1, Value delta,
2149 int inputSize,
2150 ImplicitLocOpBuilder &b) -> Value {
2151 if (inputSize == 1)
2152 return val0;
2153 Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
2154 Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
2155 Value mul1 = arith::MulFOp::create(b, val1, delta);
2156 return arith::AddFOp::create(b, mul0, mul1);
2157 };
2158
2159 // Linalg equivalent to the section below:
2160 // topAcc = v00 * (unit_x - dx);
2161 // topAcc += v01 * dx;
2162 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
2163
2164 // Linalg equivalent to the section below:
2165 // bottomAcc = v10 * (unit_x - dx);
2166 // bottomAcc += v11 * dx;
2167 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
2168
2169 // Linalg equivalent to the section below:
2170 // result = topAcc * (unit_y - dy) + bottomAcc * dy
2171 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
2172 linalg::YieldOp::create(b, result);
2173 } else {
2174 // Perform in quantized space.
2175 y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
2176 y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
2177 y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
2178 y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);
2179
2180 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
2181 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
2182 dx = arith::ExtSIOp::create(b, resultETy, dx);
2183 dy = arith::ExtSIOp::create(b, resultETy, dy);
2184 }
2185
2186 Value yScaleNExt = yScaleN;
2187 Value xScaleNExt = xScaleN;
2188
2189 const int64_t scaleBitwidth =
2190 xScaleN.getType().getIntOrFloatBitWidth();
2191 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2192 yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
2193 xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
2194 }
2195
2196 auto interpolate = [](Value val0, Value val1, Value weight1,
2197 Value scale, int inputSize,
2198 ImplicitLocOpBuilder &b) -> Value {
2199 if (inputSize == 1)
2200 return arith::MulIOp::create(b, val0, scale);
2201 Value weight0 = arith::SubIOp::create(b, scale, weight1);
2202 Value mul0 = arith::MulIOp::create(b, val0, weight0);
2203 Value mul1 = arith::MulIOp::create(b, val1, weight1);
2204 return arith::AddIOp::create(b, mul0, mul1);
2205 };
2206
2207 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2208 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2209 Value result =
2210 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2211 linalg::YieldOp::create(b, result);
2212 }
2213 }
2214 }
2215
2216 rewriter.replaceOp(op, resize);
2217 return success();
2218 }
2219};
2220
2221// At the codegen level any identity operations should be removed. Any cases
2222// where identity is load-bearing (e.g. cross device computation) should be
2223// handled before lowering to codegen.
2224template <typename SrcOp>
2225class IdentityNConverter : public OpRewritePattern<SrcOp> {
2226public:
2227 using OpRewritePattern<SrcOp>::OpRewritePattern;
2228
2229 LogicalResult matchAndRewrite(SrcOp op,
2230 PatternRewriter &rewriter) const final {
2231 rewriter.replaceOp(op, op.getOperation()->getOperands());
2232 return success();
2233 }
2234};
2235
2236template <typename SrcOp>
2237class ReduceConverter : public OpRewritePattern<SrcOp> {
2238public:
2239 using OpRewritePattern<SrcOp>::OpRewritePattern;
2240
2241 LogicalResult matchAndRewrite(SrcOp reduceOp,
2242 PatternRewriter &rewriter) const final {
2243 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2244 }
2245};
2246
2247class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
2248public:
2249 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
2250
2251 LogicalResult matchAndRewrite(tosa::ReverseOp op,
2252 PatternRewriter &rewriter) const final {
2253 auto loc = op.getLoc();
2254 Value input = op.getInput1();
2255 auto inputTy = cast<ShapedType>(input.getType());
2256 auto resultTy = cast<ShapedType>(op.getType());
2257 auto axis = op.getAxis();
2258
2259 SmallVector<Value> dynDims;
2260 for (int i = 0; i < inputTy.getRank(); i++) {
2261 if (inputTy.isDynamicDim(i)) {
2262 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2263 }
2264 }
2265
2266 Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
2267
2268 // First create the empty output buffer.
2269 auto emptyTensor = tensor::EmptyOp::create(
2270 rewriter, loc, inputTy.getShape(),
2271 inputTy.getElementType(), ArrayRef<Value>({dynDims}))
2272 .getResult();
2273 SmallVector<AffineMap, 2> affineMaps = {
2274 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2275
2276 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2277 op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
2278 getNParallelLoopsAttrs(resultTy.getRank()),
2279 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2280 llvm::SmallVector<Value> indices;
2281 for (unsigned int i = 0; i < inputTy.getRank(); i++) {
2282 Value index =
2283 linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
2284 if (i == axis) {
2285 auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
2286 auto sizeMinusOne =
2287 arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
2288 index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
2289 index);
2290 }
2291
2292 indices.push_back(index);
2293 }
2294
2295 auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
2296 input, indices);
2297 linalg::YieldOp::create(nestedBuilder, op.getLoc(),
2298 extract.getResult());
2299 });
2300 return success();
2301 }
2302};
2303
2304// This converter translates a tile operation to a reshape, broadcast, reshape.
2305// The first reshape minimally expands each tiled dimension to include a
2306// preceding size-1 dim. This dim is then broadcast to the appropriate
2307// multiple.
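// For example (an illustrative shape): tiling tensor<2x3xf32> with multiples
// [3, 1] broadcasts the input into the expanded shape tensor<3x2x1x3xf32> via
// a linalg.generic and then reshapes the result to the final tensor<6x3xf32>.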
2308struct TileConverter : public OpConversionPattern<tosa::TileOp> {
2309 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
2310
2311 LogicalResult
2312 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2313 ConversionPatternRewriter &rewriter) const override {
2314 auto loc = op.getLoc();
2315 auto input = op.getInput1();
2316 auto inputTy = cast<ShapedType>(input.getType());
2317 auto inputShape = inputTy.getShape();
2318 auto resultTy = cast<ShapedType>(op.getType());
2319 auto elementTy = inputTy.getElementType();
2320 int64_t rank = inputTy.getRank();
2321
2322 SmallVector<int64_t> multiples;
2323 if (failed(op.getConstantMultiples(multiples)))
2324 return failure();
2325
2326 // Broadcast the newly added dimensions to their appropriate multiple.
2327 SmallVector<int64_t, 2> genericShape;
2328 for (int i = 0; i < rank; i++) {
2329 int64_t dim = multiples[i];
2330 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2331 genericShape.push_back(inputShape[i]);
2332 }
2333
2334 SmallVector<Value> dynDims;
2335 for (int i = 0; i < inputTy.getRank(); i++) {
2336 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2337 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2338 }
2339 }
2340
2341 auto emptyTensor = tensor::EmptyOp::create(
2342 rewriter, op.getLoc(), genericShape, elementTy, dynDims);
2343
2344 // We need to map the input shape to the non-broadcasted dimensions.
2345 SmallVector<AffineExpr, 4> dimExprs;
2346 dimExprs.reserve(rank);
2347 for (unsigned i = 0; i < rank; ++i)
2348 dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
2349
2350 auto readAffineMap =
2351 AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
2352 rewriter.getContext());
2353
2354 SmallVector<AffineMap, 2> affineMaps = {
2355 readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
2356
2357 auto genericOp = linalg::GenericOp::create(
2358 rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
2359 ValueRange{emptyTensor}, affineMaps,
2360 getNParallelLoopsAttrs(genericShape.size()),
2361 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2362 linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
2363 });
2364
2365 auto shapeValue = getTosaConstShape(
2366 rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
2367 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2368 op, resultTy, genericOp.getResult(0), shapeValue);
2369 return success();
2370 }
2371};
2372
2373// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
2374// op, producing two output buffers.
2375//
2376// The first output buffer contains the index of the found maximum value. It is
2377// initialized to 0 and has the resulting integer type.
2378//
2379// The second output buffer contains the maximum value found. It is initialized
2380// to the minimum representable value of the input element type. After being
2381// populated by indexed_generic, this buffer is discarded as only the index is
2382// requested.
2383//
2384// The indexed_generic op updates both the maximum value and index if the
2385// current value exceeds the running max.
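// For floating point inputs the compare predicate encodes the NaN mode: in
// IGNORE mode an ordered greater-than (OGT) is used so NaN candidates never
// replace the running max, while in PROPAGATE mode an unordered greater-than
// combined with an "old max is not NaN" check lets a NaN take over the max.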
2386class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2387public:
2388 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2389
2390 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2391 PatternRewriter &rewriter) const final {
2392 auto loc = argmaxOp.getLoc();
2393 Value input = argmaxOp.getInput();
2394 auto inputTy = cast<ShapedType>(input.getType());
2395 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2396 auto inElementTy = inputTy.getElementType();
2397 auto outElementTy = resultTy.getElementType();
2398 int axis = argmaxOp.getAxis();
2399 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2400
2401 if (!isa<IntegerType>(outElementTy))
2402 return rewriter.notifyMatchFailure(
2403 argmaxOp,
2404 "tosa.arg_max to linalg.* requires integer-like result type");
2405
2406 SmallVector<Value> dynDims;
2407 for (int i = 0; i < inputTy.getRank(); i++) {
2408 if (inputTy.isDynamicDim(i) && i != axis) {
2409 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2410 }
2411 }
2412
2413 // First fill the output buffer for the index.
2414 auto emptyTensorIdx =
2415 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2416 outElementTy, dynDims)
2417 .getResult();
2418 auto fillValueIdx = arith::ConstantOp::create(
2419 rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2420 auto filledTensorIdx =
2421 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
2422 ValueRange{emptyTensorIdx})
2423 .result();
2424
2425 // Second fill the output buffer for the running max.
2426 auto emptyTensorMax =
2427 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2428 dynDims)
2429 .getResult();
2430 auto fillValueMaxAttr =
2431 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2432
2433 if (!fillValueMaxAttr)
2434 return rewriter.notifyMatchFailure(
2435 argmaxOp, "unsupported tosa.argmax element type");
2436
2437 auto fillValueMax =
2438 arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2439 auto filledTensorMax =
2440 linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
2441 ValueRange{emptyTensorMax})
2442 .result();
2443
2444 // We need to reduce along the arg-max axis, with parallel operations along
2445 // the rest.
2446 SmallVector<utils::IteratorType, 4> iteratorTypes;
2447 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2448 iteratorTypes[axis] = utils::IteratorType::reduction;
2449
2450 SmallVector<AffineExpr, 2> srcExprs;
2451 SmallVector<AffineExpr, 2> dstExprs;
2452 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2453 srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2454 if (axis != i)
2455 dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
2456 }
2457
2458 bool didEncounterError = false;
2459 auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
2460 rewriter.getContext());
2461 auto linalgOp = linalg::GenericOp::create(
2462 rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2463 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2464 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2465 ValueRange blockArgs) {
2466 auto newValue = blockArgs[0];
2467 auto oldIndex = blockArgs[1];
2468 auto oldValue = blockArgs[2];
2469
2470 Value newIndex = arith::IndexCastOp::create(
2471 rewriter, nestedLoc, oldIndex.getType(),
2472 linalg::IndexOp::create(rewriter, loc, axis));
2473
2474 Value predicate;
2475 if (isa<FloatType>(inElementTy)) {
2476 if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
2477 // Only update the index & max value for non-NaN values. If all
2478 // values are NaNs, the initial index, which is 0, is returned.
2479 predicate = arith::CmpFOp::create(rewriter, nestedLoc,
2480 arith::CmpFPredicate::OGT,
2481 newValue, oldValue);
2482 } else {
2483 // Update max value if either of the following is true:
2484 // - new value is bigger
2485 // - cur max is not NaN and new value is NaN
2486 Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
2487 arith::CmpFPredicate::UGT,
2488 newValue, oldValue);
2489 Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
2490 arith::CmpFPredicate::ORD,
2491 oldValue, oldValue);
2492 predicate = arith::AndIOp::create(
2493 rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2494 }
2495 } else if (isa<IntegerType>(inElementTy)) {
2496 predicate = arith::CmpIOp::create(rewriter, nestedLoc,
2497 arith::CmpIPredicate::sgt,
2498 newValue, oldValue);
2499 } else {
2500 didEncounterError = true;
2501 return;
2502 }
2503
2504 auto resultMax = arith::SelectOp::create(
2505 rewriter, nestedLoc, predicate, newValue, oldValue);
2506 auto resultIndex = arith::SelectOp::create(
2507 rewriter, nestedLoc, predicate, newIndex, oldIndex);
2508 linalg::YieldOp::create(nestedBuilder, nestedLoc,
2509 ValueRange({resultIndex, resultMax}));
2510 });
2511
2512 if (didEncounterError)
2513 return rewriter.notifyMatchFailure(
2514 argmaxOp, "unsupported tosa.argmax element type");
2515
2516 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2517 return success();
2518 }
2519};
2520
2521class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2522public:
2523 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2524 LogicalResult
2525 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2526 ConversionPatternRewriter &rewriter) const final {
2527 auto input = adaptor.getOperands()[0];
2528 auto indices = adaptor.getOperands()[1];
2529
2530 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2531 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2532 if (!valuesTy || !resultTy)
2533 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2534
2535 auto dynamicDims = inferDynamicDimsForGather(
2536 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2537
2538 auto resultElementTy = resultTy.getElementType();
2539
2540 auto loc = op.getLoc();
2541 auto emptyTensor =
2542 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2543 resultElementTy, dynamicDims)
2544 .getResult();
2545
2546 SmallVector<AffineMap, 2> affineMaps = {
2547 AffineMap::get(
2548 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2549 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2550 rewriter.getContext()),
2551 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2552
2553 auto genericOp = linalg::GenericOp::create(
2554 rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2555 ValueRange{emptyTensor}, affineMaps,
2556 getNParallelLoopsAttrs(resultTy.getRank()),
2557 [&](OpBuilder &b, Location loc, ValueRange args) {
2558 auto indexValue = args[0];
2559 auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
2560 Value index1 = arith::IndexCastOp::create(
2561 rewriter, loc, rewriter.getIndexType(), indexValue);
2562 auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
2563 Value extract = tensor::ExtractOp::create(
2564 rewriter, loc, input, ValueRange{index0, index1, index2});
2565 linalg::YieldOp::create(rewriter, loc, extract);
2566 });
2567 rewriter.replaceOp(op, genericOp.getResult(0));
2568 return success();
2569 }
2570
2571 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2572 Location loc,
2573 Value values,
2574 Value indices) {
2575 llvm::SmallVector<Value> results;
2576
2577 auto addDynamicDimension = [&](Value source, int64_t dim) {
2578 auto sz = tensor::getMixedSize(builder, loc, source, dim);
2579 if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2580 results.push_back(dimValue);
2581 };
2582
2583 addDynamicDimension(values, 0);
2584 addDynamicDimension(indices, 1);
2585 addDynamicDimension(values, 2);
2586 return results;
2587 }
2588};
2589
2590// Lowers the TableOp to a series of gathers and numeric operations. This
2591// includes interpolation between the high/low values. For the I8 variant, this
2592// simplifies to a single gather operation.
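// For the i8 case this is roughly (illustrative, indices simplified):
//   %idx = arith.index_cast %in : i8 to index
//   %val = tensor.extract %table[%idx + 128] : tensor<256xi8>
// i.e. the signed input is shifted into the [0, 255] table range.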
2593class TableConverter : public OpRewritePattern<tosa::TableOp> {
2594public:
2595 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2596
2597 LogicalResult matchAndRewrite(tosa::TableOp op,
2598 PatternRewriter &rewriter) const final {
2599 auto loc = op.getLoc();
2600 Value input = op.getInput1();
2601 Value table = op.getTable();
2602 auto inputTy = cast<ShapedType>(input.getType());
2603 auto tableTy = cast<ShapedType>(table.getType());
2604 auto resultTy = cast<ShapedType>(op.getType());
2605
2606 auto inputElementTy = inputTy.getElementType();
2607 auto tableElementTy = tableTy.getElementType();
2608 auto resultElementTy = resultTy.getElementType();
2609
2610 SmallVector<Value> dynDims;
2611 for (int i = 0; i < resultTy.getRank(); ++i) {
2612 if (inputTy.isDynamicDim(i)) {
2613 dynDims.push_back(
2614 tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
2615 }
2616 }
2617
2618 auto emptyTensor =
2619 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2620 resultElementTy, dynDims)
2621 .getResult();
2622
2623 SmallVector<AffineMap, 2> affineMaps = {
2624 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2625 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2626
2627 auto genericOp = linalg::GenericOp::create(
2628 rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
2629 affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
2630 rewriter.replaceOp(op, genericOp.getResult(0));
2631
2632 {
2633 OpBuilder::InsertionGuard regionGuard(rewriter);
2634 Block *block = rewriter.createBlock(
2635 &genericOp.getRegion(), genericOp.getRegion().end(),
2636 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2637
2638 auto inputValue = block->getArgument(0);
2639 rewriter.setInsertionPointToStart(block);
2640 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2641 resultElementTy.isInteger(8)) {
2642 Value index = arith::IndexCastOp::create(
2643 rewriter, loc, rewriter.getIndexType(), inputValue);
2644 Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
2645 index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
2646 index, offset);
2647 Value extract =
2648 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2649 linalg::YieldOp::create(rewriter, loc, extract);
2650 return success();
2651 }
2652
2653 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2654 resultElementTy.isInteger(32)) {
2655 Value extend = arith::ExtSIOp::create(
2656 rewriter, loc, rewriter.getI32Type(), inputValue);
2657
2658 auto offset = arith::ConstantOp::create(
2659 rewriter, loc, rewriter.getI32IntegerAttr(32768));
2660 auto seven = arith::ConstantOp::create(rewriter, loc,
2661 rewriter.getI32IntegerAttr(7));
2662 auto one = arith::ConstantOp::create(rewriter, loc,
2663 rewriter.getI32IntegerAttr(1));
2664 auto b1111111 = arith::ConstantOp::create(
2665 rewriter, loc, rewriter.getI32IntegerAttr(127));
2666
2667 // Compute the index and fractional part from the input value:
2668 // value = value + 32768
2669 // index = value >> 7;
2670 // fraction = 0b01111111 & value
2671 auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
2672 Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
2673 Value fraction =
2674 arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
2675
2676 // Extract the base and next values from the table.
2677 // base = (int32_t) table[index];
2678 // next = (int32_t) table[index + 1];
2679 Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
2680
2681 index = arith::IndexCastOp::create(rewriter, loc,
2682 rewriter.getIndexType(), index);
2683 indexPlusOne = arith::IndexCastOp::create(
2684 rewriter, loc, rewriter.getIndexType(), indexPlusOne);
2685
2686 Value base =
2687 tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
2688 Value next = tensor::ExtractOp::create(rewriter, loc, table,
2689 ValueRange{indexPlusOne});
2690
2691 base =
2692 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
2693 next =
2694 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
2695
2696 // Use the fractional part to interpolate between the input values:
2697 // result = (base << 7) + (next - base) * fraction
2698 Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
2699 Value diff = arith::SubIOp::create(rewriter, loc, next, base);
2700 Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
2701 Value result =
2702 arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
2703
2704 linalg::YieldOp::create(rewriter, loc, result);
2705
2706 return success();
2707 }
2708 }
2709
2710 return rewriter.notifyMatchFailure(
2711 op, "unable to create body for tosa.table op");
2712 }
2713};
2714
2715struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2716 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2717
2718 static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
2719
2720 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2721 OpFoldResult ofr) {
2722 auto one = arith::ConstantIndexOp::create(builder, loc, 1);
2723 auto two = arith::ConstantIndexOp::create(builder, loc, 2);
2724
2725 auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
2726 auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2727 auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2728 return getAsOpFoldResult(plusOne);
2729 }
2730
2731 static RankedTensorType
2732 computeOutputShape(OpBuilder &builder, Location loc, Value input,
2733 llvm::SmallVectorImpl<Value> &dynamicSizes) {
2734 // Get [N, H, W]
2735 auto dims = tensor::getMixedSizes(builder, loc, input);
2736
2737 // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2738 // output tensors.
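    // For example, a static W of 8 becomes (8 / 2) + 1 = 5 output columns.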
2739 dims[2] = halfPlusOne(builder, loc, dims[2]);
2740
2741 llvm::SmallVector<int64_t, 3> staticSizes;
2742 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2743
2744 auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2745 return RankedTensorType::get(staticSizes, elementType);
2746 }
2747
2748 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2749 RankedTensorType type,
2750 llvm::ArrayRef<Value> dynamicSizes) {
2751 auto emptyTensor =
2752 tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2753 auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
2754 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2755 auto filledTensor =
2756 linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
2757 ValueRange{emptyTensor})
2758 .result();
2759 return filledTensor;
2760 }
2761
2762 static Value castIndexToFloat(OpBuilder &builder, Location loc,
2763 FloatType type, Value value) {
2764 auto integerVal = arith::IndexCastUIOp::create(
2765 builder, loc,
2766 type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2767 : builder.getI32Type(),
2768 value);
2769
2770 return arith::UIToFPOp::create(builder, loc, type, integerVal);
2771 }
2772
2773 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2774 FloatType type, int64_t index) {
2775 auto indexVal = linalg::IndexOp::create(builder, loc, index);
2776 return castIndexToFloat(builder, loc, type, indexVal);
2777 }
2778
2779 template <typename... Args>
2780 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2781 Args... args) {
2782 return {builder.getAffineDimExpr(args)...};
2783 }
2784
2785 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2786 PatternRewriter &rewriter) const override {
2787 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2788 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2789 return rewriter.notifyMatchFailure(rfft2d,
2790 "only supports ranked tensors");
2791 }
2792
2793 auto loc = rfft2d.getLoc();
2794 auto input = rfft2d.getInputReal();
2795 auto elementType =
2796 dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2797 if (!elementType)
2798 return rewriter.notifyMatchFailure(rfft2d,
2799 "only supports float element types");
2800
2801 // Compute the output type and set of dynamic sizes
2802 llvm::SmallVector<Value> dynamicSizes;
2803 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2804
2805 // Iterator types for the linalg.generic implementation
2806 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2807 utils::IteratorType::parallel, utils::IteratorType::parallel,
2808 utils::IteratorType::parallel, utils::IteratorType::reduction,
2809 utils::IteratorType::reduction};
2810
2811 // Inputs/outputs to the linalg.generic implementation
2812 llvm::SmallVector<Value> genericOpInputs = {input};
2813 llvm::SmallVector<Value> genericOpOutputs = {
2814 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2815 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2816
2817 // Indexing maps for input and output tensors
2818 auto indexingMaps = AffineMap::inferFromExprList(
2819 llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2820 affineDimsExpr(rewriter, 0, 1, 2),
2821 affineDimsExpr(rewriter, 0, 1, 2)},
2822 rewriter.getContext());
2823
2824 // Width and height dimensions of the original input.
2825 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2826 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2827
2828 // Constants and dimension sizes
2829 auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2830 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2831 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2832 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2833
2834 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2835 Value valReal = args[0];
2836 Value sumReal = args[1];
2837 Value sumImag = args[2];
2838
2839 // Indices for angle computation
2840 Value oy = linalg::IndexOp::create(builder, loc, 1);
2841 Value ox = linalg::IndexOp::create(builder, loc, 2);
2842 Value iy = linalg::IndexOp::create(builder, loc, 3);
2843 Value ix = linalg::IndexOp::create(builder, loc, 4);
2844
2845 // Calculating angle without integer parts of components as sin/cos are
2846 // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2847 // / W);
2848 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2849 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2850
2851 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2852 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2853
2854 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2855 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2856
2857 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2858 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2859 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2860 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2861
2862 // realComponent = valReal * cos(angle)
2863 // imagComponent = valReal * sin(angle)
2864 auto cosAngle = math::CosOp::create(builder, loc, angle);
2865 auto sinAngle = math::SinOp::create(builder, loc, angle);
2866 auto realComponent =
2867 arith::MulFOp::create(builder, loc, valReal, cosAngle);
2868 auto imagComponent =
2869 arith::MulFOp::create(builder, loc, valReal, sinAngle);
2870
2871 // outReal = sumReal + realComponent
2872 // outImag = sumImag - imagComponent
2873 auto outReal =
2874 arith::AddFOp::create(builder, loc, sumReal, realComponent);
2875 auto outImag =
2876 arith::SubFOp::create(builder, loc, sumImag, imagComponent);
2877
2878 linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
2879 };
2880
2881 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2882 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2883 indexingMaps, iteratorTypes, buildBody);
2884
2885 return success();
2886 }
2887};
2888
2889struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2890 using OpRewritePattern::OpRewritePattern;
2891
2892 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2893 PatternRewriter &rewriter) const override {
2894 if (!llvm::all_of(fft2d->getOperandTypes(),
2895 RFFT2dConverter::isRankedTensor) ||
2896 !llvm::all_of(fft2d->getResultTypes(),
2897 RFFT2dConverter::isRankedTensor)) {
2898 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2899 }
2900
2901 Location loc = fft2d.getLoc();
2902 Value input_real = fft2d.getInputReal();
2903 Value input_imag = fft2d.getInputImag();
2904 BoolAttr inverse = fft2d.getInverseAttr();
2905
2906 auto real_el_ty = cast<FloatType>(
2907 cast<ShapedType>(input_real.getType()).getElementType());
2908 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2909 cast<ShapedType>(input_imag.getType()).getElementType());
2910
2911 assert(real_el_ty == imag_el_ty);
2912
2913 // Compute the output type and set of dynamic sizes
2914 SmallVector<Value> dynamicSizes;
2915
2916 // Get [N, H, W]
2917 auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
2918
2919 SmallVector<int64_t, 3> staticSizes;
2920 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2921
2922 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2923
2924 // Iterator types for the linalg.generic implementation
2925 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2926 utils::IteratorType::parallel, utils::IteratorType::parallel,
2927 utils::IteratorType::parallel, utils::IteratorType::reduction,
2928 utils::IteratorType::reduction};
2929
2930 // Inputs/outputs to the linalg.generic implementation
2931 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2932 SmallVector<Value> genericOpOutputs = {
2933 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2934 dynamicSizes),
2935 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2936 dynamicSizes)};
2937
2938 // Indexing maps for input and output tensors
2939 auto indexingMaps = AffineMap::inferFromExprList(
2940 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2941 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2942 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2943 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2944 rewriter.getContext());
2945
2946 // Width and height dimensions of the original input.
2947 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2948 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2949
2950 // Constants and dimension sizes
2951 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2952 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2953 Value constH =
2954 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2955 Value constW =
2956 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2957
2958 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2959 Value valReal = args[0];
2960 Value valImag = args[1];
2961 Value sumReal = args[2];
2962 Value sumImag = args[3];
2963
2964 // Indices for angle computation
2965 Value oy = linalg::IndexOp::create(builder, loc, 1);
2966 Value ox = linalg::IndexOp::create(builder, loc, 2);
2967 Value iy = linalg::IndexOp::create(builder, loc, 3);
2968 Value ix = linalg::IndexOp::create(builder, loc, 4);
2969
2970 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2971 // ox) % W ) / W);
2972 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2973 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2974
2975 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2976 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2977
2978 auto iyRemFloat =
2979 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2980 auto ixRemFloat =
2981 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2982
2983 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2984 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2985
2986 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2987 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2988
2989 if (inverse.getValue()) {
2990 angle = arith::MulFOp::create(
2991 builder, loc, angle,
2992 arith::ConstantOp::create(rewriter, loc,
2993 rewriter.getFloatAttr(real_el_ty, -1.0)));
2994 }
2995
2996 // realComponent = val_real * cos(a) + val_imag * sin(a);
2997 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2998 auto cosAngle = math::CosOp::create(builder, loc, angle);
2999 auto sinAngle = math::SinOp::create(builder, loc, angle);
3000
3001 auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
3002 auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
3003 auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
3004
3005 auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
3006 auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
3007
3008 auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
3009
3010 // outReal = sumReal + realComponent
3011 // outImag = sumImag + imagComponent
3012 auto outReal =
3013 arith::AddFOp::create(builder, loc, sumReal, realComponent);
3014 auto outImag =
3015 arith::AddFOp::create(builder, loc, sumImag, imagComponent);
3016
3017 linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
3018 };
3019
3020 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
3021 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
3022 indexingMaps, iteratorTypes, buildBody);
3023
3024 return success();
3025 }
3026};
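For reference, the body above accumulates the full complex 2D DFT, with the angle negated when the inverse attribute is set. A hand-derived summary matching realComponent and imagComponent in the code (scaling of the inverse transform, if any, is outside this lowering):

\begin{aligned}
a &= s \cdot 2\pi\left(\frac{(i_y\, o_y) \bmod H}{H} + \frac{(i_x\, o_x) \bmod W}{W}\right),
\qquad s = \begin{cases} -1 & \text{inverse} \\ +1 & \text{otherwise,} \end{cases}\\
\text{out\_real}[n, o_y, o_x] &= \sum_{i_y, i_x} \bigl(\text{in\_real}[n, i_y, i_x]\cos a + \text{in\_imag}[n, i_y, i_x]\sin a\bigr),\\
\text{out\_imag}[n, o_y, o_x] &= \sum_{i_y, i_x} \bigl(-\text{in\_real}[n, i_y, i_x]\sin a + \text{in\_imag}[n, i_y, i_x]\cos a\bigr).
\end{aligned}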
3027
3028} // namespace
3029
3030void mlir::tosa::populateTosaToLinalgConversionPatterns(
3031 const TypeConverter &converter, RewritePatternSet *patterns) {
3032
3033 // We have multiple resize converters to handle degenerate cases; patterns with a larger benefit are tried first.
3034 patterns->add<GenericResizeConverter>(patterns->getContext(),
3035 /*benefit=*/100);
3036 patterns->add<ResizeUnaryConverter>(patterns->getContext(),
3037 /*benefit=*/200);
3038 patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
3039 /*benefit=*/300);
3040
3041 patterns->add<
3042 // clang-format off
3043 PointwiseConverter<tosa::AddOp>,
3044 PointwiseConverter<tosa::SubOp>,
3045 PointwiseConverter<tosa::MulOp>,
3046 PointwiseConverter<tosa::IntDivOp>,
3047 PointwiseConverter<tosa::NegateOp>,
3048 PointwiseConverter<tosa::PowOp>,
3049 PointwiseConverter<tosa::ReciprocalOp>,
3050 PointwiseConverter<tosa::RsqrtOp>,
3051 PointwiseConverter<tosa::LogOp>,
3052 PointwiseConverter<tosa::ExpOp>,
3053 PointwiseConverter<tosa::AbsOp>,
3054 PointwiseConverter<tosa::SinOp>,
3055 PointwiseConverter<tosa::CosOp>,
3056 PointwiseConverter<tosa::TanhOp>,
3057 PointwiseConverter<tosa::ErfOp>,
3058 PointwiseConverter<tosa::BitwiseAndOp>,
3059 PointwiseConverter<tosa::BitwiseOrOp>,
3060 PointwiseConverter<tosa::BitwiseNotOp>,
3061 PointwiseConverter<tosa::BitwiseXorOp>,
3062 PointwiseConverter<tosa::LogicalAndOp>,
3063 PointwiseConverter<tosa::LogicalNotOp>,
3064 PointwiseConverter<tosa::LogicalOrOp>,
3065 PointwiseConverter<tosa::LogicalXorOp>,
3066 PointwiseConverter<tosa::CastOp>,
3067 PointwiseConverter<tosa::LogicalLeftShiftOp>,
3068 PointwiseConverter<tosa::LogicalRightShiftOp>,
3069 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
3070 PointwiseConverter<tosa::ClzOp>,
3071 PointwiseConverter<tosa::SelectOp>,
3072 PointwiseConverter<tosa::GreaterOp>,
3073 PointwiseConverter<tosa::GreaterEqualOp>,
3074 PointwiseConverter<tosa::EqualOp>,
3075 PointwiseConverter<tosa::MaximumOp>,
3076 PointwiseConverter<tosa::MinimumOp>,
3077 PointwiseConverter<tosa::CeilOp>,
3078 PointwiseConverter<tosa::FloorOp>,
3079 PointwiseConverter<tosa::ClampOp>,
3080 PointwiseConverter<tosa::SigmoidOp>
3081 >(converter, patterns->getContext());
3082
3083 patterns->add<
3084 IdentityNConverter<tosa::IdentityOp>,
3085 ReduceConverter<tosa::ReduceAllOp>,
3086 ReduceConverter<tosa::ReduceAnyOp>,
3087 ReduceConverter<tosa::ReduceMinOp>,
3088 ReduceConverter<tosa::ReduceMaxOp>,
3089 ReduceConverter<tosa::ReduceSumOp>,
3090 ReduceConverter<tosa::ReduceProductOp>,
3091 ArgMaxConverter,
3092 GatherConverter,
3093 RescaleConverter,
3094 ReverseConverter,
3095 RFFT2dConverter,
3096 FFT2dConverter,
3097 TableConverter,
3098 TileConverter>(patterns->getContext());
3099 // clang-format on
3100}
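A typical driver for these patterns is a conversion pass that populates them into a RewritePatternSet and runs a dialect conversion; in-tree, the TosaToLinalg pass plays that role. The sketch below only illustrates the shape of such a driver; the pass name, identity type converter, and legality setup are assumptions for the example, not the in-tree pass.

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical pass that applies the populated TOSA -> Linalg patterns.
struct ExampleTosaToLinalgPass
    : PassWrapper<ExampleTosaToLinalgPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExampleTosaToLinalgPass)

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // TOSA ops must be rewritten away; all other ops are left alone.
    ConversionTarget target(*ctx);
    target.addIllegalDialect<tosa::TosaDialect>();
    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

    // Identity type converter, for illustration only.
    TypeConverter converter;
    converter.addConversion([](Type type) { return type; });

    RewritePatternSet patterns(ctx);
    mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);

    if (failed(applyFullConversion(getOperation(), target,
                                   std::move(patterns))))
      signalPassFailure();
  }
};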