MLIR 23.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
20#include "mlir/Dialect/Traits.h"
23#include "mlir/IR/Matchers.h"
27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/APInt.h"
29
30#include <functional>
31
32using namespace mlir;
33using namespace mlir::tosa;
34
35namespace {
36OpFoldResult foldToInputIfTypeMatches(Type typeRef, Value input) {
37 return input.getType() == typeRef ? OpFoldResult(input) : OpFoldResult{};
38}
39} // namespace
40
41//===----------------------------------------------------------------------===//
42// Operator Canonicalizers.
43//===----------------------------------------------------------------------===//
44
45//===----------------------------------------------------------------------===//
46// Tensor Data Engine Operators.
47//===----------------------------------------------------------------------===//
48
49// Check that the zero point of the tensor and padding operations are aligned.
50static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
51 // Check that padConst is a constant value and a scalar tensor
52 DenseElementsAttr padConstAttr;
53 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
54 (padConstAttr.size() != 1)) {
55 return false;
56 }
57
58 // Check that floating point pad is zero
59 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
60 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
61 return padConstVal == 0.0f;
62 }
63
64 // Check that the zp and padConst align for the integer (quantized) case
65 if (auto padConstIntAttr =
66 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
68 // Check that zp is a constant value and a scalar tensor
69 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
70 return false;
71 }
72
73 // Check equality
74 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
75 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
76 return zpVal == padConstVal;
77 }
78
79 // Bail-out on unsupported type
80 return false;
81}
82
83namespace {
84template <typename OpTy>
85struct PoolPadFoldAdaptor;
86
87template <>
88struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
89 using OpTy = tosa::MaxPool2dOp;
90 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
91 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
92 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
93 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
94 return false;
95 return true;
96 }
97 static bool checkPadConstCompliance(OpTy, Value padConst) {
98 // Check that padConst is a constant value and a scalar tensor
99 DenseElementsAttr padConstAttr;
100 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
101 padConstAttr.size() != 1) {
102 return false;
103 }
104
105 // Pad needs to be in the minimum value to be able to merge
106 if (auto padConstFpAttr =
107 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
108 const APFloat padConstVal = *padConstFpAttr.begin();
109 const APFloat lowestVal =
110 APFloat::getLargest(padConstVal.getSemantics(), true);
111 return padConstVal == lowestVal;
112 }
113 if (auto padConstIntAttr =
114 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
115 const APInt padConstVal = *padConstIntAttr.begin();
116 const unsigned int bitWidth = padConstVal.getBitWidth();
117 const APInt lowestVal =
118 padConstIntAttr.getElementType().isUnsignedInteger()
119 ? APInt::getZero(bitWidth)
120 : APInt::getSignedMinValue(bitWidth);
121 return padConstVal == lowestVal;
122 }
123
124 // Bail-out on unsupported type
125 return false;
126 }
127 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
128 Value padInput, ArrayRef<int64_t> newPad) {
129 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
130 op, op.getType(), padInput, op.getKernel(), op.getStride(),
131 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
132 }
133};
134
135template <typename OpTy>
136struct ConvPadFoldAdaptor {
137 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
138 return true;
139 }
140 static bool checkPadConstCompliance(OpTy op, Value padConst) {
141 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
142 }
143 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
144 Value padInput, ArrayRef<int64_t> newPad) {
145 rewriter.replaceOpWithNewOp<OpTy>(
146 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
147 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
148 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
149 }
150};
151
152// Pattern attempts to fold a `tosa.pad` operator to a following tensor
153// operation like `tosa.conv2d` by merging the padding associated with the
154// pad operator directly to the implicit padding of the tensor operation.
155// This helps eliminate the explicit padding operator if unused.
156template <typename OpTy, typename AdaptorTy>
157struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
158 using OpRewritePattern<OpTy>::OpRewritePattern;
159
160 LogicalResult matchAndRewrite(OpTy tensorOp,
161 PatternRewriter &rewriter) const override {
162 // Check producer is a tosa::PadOp
163 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
164 if (!padOp)
165 return rewriter.notifyMatchFailure(tensorOp,
166 "Producer must be a tosa::PadOp.");
167
168 // Validate that tensor operation has sane padding
169 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
170 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
171 return rewriter.notifyMatchFailure(
172 tensorOp, "Tensor operation padding shall have 4 elements.");
173
174 // Validate tosa::PadOp padding
175 DenseIntElementsAttr padOpPadding;
176 if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
177 return rewriter.notifyMatchFailure(
178 tensorOp,
179 "The `padding` input specified on the tosa::PadOp must be constant.");
180 }
181 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
182 // C_after
183 if (padOpPadding.size() != 8)
184 return rewriter.notifyMatchFailure(tensorOp,
185 "Pad padding should have 8 elements.");
186 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
187 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
188 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
189 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
190 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
191 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
192 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
193 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
194
195 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
196 return rewriter.notifyMatchFailure(
197 tensorOp, "Folding padding in N or C dimensions is not supported.");
198
199 // Fold padding from Pad into the tensor operation
200 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
201 SmallVector<int64_t> foldedPad(tensorOpPad.size());
202 foldedPad[0] = padHBefore + tensorOpPad[0];
203 foldedPad[1] = padHAfter + tensorOpPad[1];
204 foldedPad[2] = padWBefore + tensorOpPad[2];
205 foldedPad[3] = padWAfter + tensorOpPad[3];
206
207 // Check kernel related restrictions
208 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
209 return rewriter.notifyMatchFailure(
210 tensorOp, "Padding size not aligned with kernel restrictions.");
211 }
212
213 // Check padding constant restrictions
214 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
215 return rewriter.notifyMatchFailure(
216 tensorOp,
217 "Padding constant is not aligned with operator zero-point.");
218 }
219
220 // Check that padding doesn't grow more than 8K level (8192) for now
221 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
222 return rewriter.notifyMatchFailure(
223 tensorOp, "Padding size more than the 8K level limit.");
224 }
225
226 // Create operator
227 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
228 foldedPad);
229
230 return success();
231 }
232};
233} // namespace
234
235void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
236 MLIRContext *context) {
237 results.add<
238 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
239 context);
240}
241
242void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
243 MLIRContext *context) {
244 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
245 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
246 context);
247}
248
250 : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
252
253 LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
254 PatternRewriter &rewriter) const override {
258 if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
259 !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
260 !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
261 return rewriter.notifyMatchFailure(
262 op, "expected constant kernel, stride, and pad operands");
263
264 auto replacement = tosa::AvgPool2dOp::create(
265 rewriter, op.getLoc(), op.getType(), op.getInput(), op.getInputZp(),
266 op.getOutputZp(), rewriter.getDenseI64ArrayAttr(kernel),
267 rewriter.getDenseI64ArrayAttr(stride),
268 rewriter.getDenseI64ArrayAttr(pad), op.getAccTypeAttr());
269 rewriter.replaceOp(op, replacement.getOutput());
270 return success();
271 }
272};
273
274void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
275 RewritePatternSet &results, MLIRContext *context) {
276 results.add<AvgPool2dAdaptiveToAvgPool2d>(context);
277}
278
279struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
281
282 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
283 PatternRewriter &rewriter) const override {
284 Value input = op.getInput();
285 Value output = op.getOutput();
286 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
287 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
288
289 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
290 return failure();
291 }
292
293 // If the output and input shapes are 1x1, then this is a no op.
294 ArrayRef<int64_t> outputShape = outputType.getShape();
295 if (outputShape[1] != 1 || outputShape[2] != 1) {
296 return failure();
297 }
298
299 ArrayRef<int64_t> inputShape = inputType.getShape();
300 if (inputShape[1] != 1 || inputShape[2] != 1) {
301 return failure();
302 }
303
304 rewriter.replaceOp(op, input);
305 return success();
306 }
307};
308
309void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
310 MLIRContext *context) {
311 results.add<MaxPool2dIsNoOp,
312 FoldPadToTensorOp<tosa::MaxPool2dOp,
313 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
314 context);
315}
316
318 : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
320
321 LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
322 PatternRewriter &rewriter) const override {
326 if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
327 !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
328 !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
329 return rewriter.notifyMatchFailure(
330 op, "expected constant kernel, stride, and pad operands");
331
332 auto replacement = tosa::MaxPool2dOp::create(
333 rewriter, op.getLoc(), op.getType(), op.getInput(),
334 rewriter.getDenseI64ArrayAttr(kernel),
335 rewriter.getDenseI64ArrayAttr(stride),
336 rewriter.getDenseI64ArrayAttr(pad), op.getNanModeAttr());
337 rewriter.replaceOp(op, replacement.getOutput());
338 return success();
339 }
340};
341
342void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
343 RewritePatternSet &results, MLIRContext *context) {
344 results.add<MaxPool2dAdaptiveToMaxPool2d>(context);
345}
346
347//===----------------------------------------------------------------------===//
348// Data Layout / Memory Reinterpretation.
349//===----------------------------------------------------------------------===//
350
351struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
352 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
353
354 LogicalResult matchAndRewrite(tosa::ConcatOp op,
355 PatternRewriter &rewriter) const override {
356 if (op.getInput1().size() != 1)
357 return failure();
358 if (op.getInput1().front().getType() != op.getType()) {
359 rewriter
360 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
361 op.getInput1().front())
362 .getResult();
363 return success();
364 }
365
366 rewriter.replaceOp(op, op.getInput1().front());
367 return success();
368 }
369};
370
371struct ConsecutiveConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
372 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
373
374 LogicalResult matchAndRewrite(tosa::ConcatOp op,
375 PatternRewriter &rewriter) const override {
376 // Rewrite consecutive concats on the same axis into a single op.
377 // Keep track of the operands so we are able to construct a new concat
378 // later. Conservatively assume that we double the number of operands when
379 // canonicalizing
380 SmallVector<Value, 8> concatOperands;
381 concatOperands.reserve(2 * op.getNumOperands());
382
383 int32_t maxNumOperands = 0;
384 if (auto targetEnvAttr = tosa::lookupTargetEnv(op))
385 maxNumOperands =
386 getTosaLevelFromEnum(targetEnvAttr.getLevel()).MAX_TENSOR_LIST_SIZE;
387
388 // Find all operands that are foldable concats
389 bool foundRewritableConcat = false;
390 for (Value operand : op.getOperands()) {
391 concatOperands.emplace_back(operand);
392
393 auto producer = operand.getDefiningOp<tosa::ConcatOp>();
394 if (!producer)
395 continue;
396
397 // Not rewritable if axes are not the same
398 if (op.getAxis() != producer.getAxis())
399 continue;
400
401 // Replace the original operand with all incoming operands
402 foundRewritableConcat = true;
403 concatOperands.pop_back();
404 llvm::append_range(concatOperands, producer->getOperands());
405 }
406
407 if (!foundRewritableConcat)
408 return rewriter.notifyMatchFailure(op,
409 "No rewritable concat operand found.");
410
411 if (maxNumOperands > 0 &&
412 concatOperands.size() > static_cast<size_t>(maxNumOperands))
413 return rewriter.notifyMatchFailure(
414 op, "Rewriting would exceed the maximum number of operands for the "
415 "target environment level.");
416
417 rewriter.replaceOpWithNewOp<tosa::ConcatOp>(
418 op, op.getType(), concatOperands, op.getAxisAttr());
419 return success();
420 }
421};
422
423void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
424 MLIRContext *context) {
426}
427
428LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
429 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
430 if (!notOp)
431 return failure();
432 rewriter.modifyOpInPlace(op, [&]() {
433 op.getOperation()->setOperands(
434 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
435 });
436 return success();
437}
438
440 : public OpRewritePattern<tosa::TransposeOp> {
442
443 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
444 PatternRewriter &rewriter) const override {
445 // Input is also TransposeOp - transpose(transpose(A)).
446 auto innerTranspose =
447 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
448 if (!innerTranspose)
449 return rewriter.notifyMatchFailure(transposeOp,
450 "input must be transpose operation");
451
452 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
453 const llvm::ArrayRef<int32_t> innerTransposePerms =
454 innerTranspose.getPerms();
455
456 if (transposePerms.size() != innerTransposePerms.size())
457 return rewriter.notifyMatchFailure(
458 transposeOp,
459 "transpose and inner transpose perms sizes must be equal");
460 if (transposePerms.empty())
461 return rewriter.notifyMatchFailure(
462 transposeOp, "transpose perms sizes must be positive");
463
464 // Consolidate transposes into one transpose.
465 SmallVector<int32_t> perms(transposePerms.size());
466 for (int i = 0, s = transposePerms.size(); i < s; ++i)
467 perms[i] = innerTransposePerms[transposePerms[i]];
468
469 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
470 transposeOp, transposeOp.getResult().getType(),
471 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
472
473 return success();
474 }
475};
476
477// Determines the case when tosa.transpose is a tosa.reshape operation.
478struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
480
481 LogicalResult matchAndRewrite(tosa::TransposeOp op,
482 PatternRewriter &rewriter) const override {
483 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
484 return rewriter.notifyMatchFailure(
485 op, "Src is from transpose, can compose transposes");
486
487 Value result = op.getResult();
488 for (Operation *subop : result.getUsers()) {
489 if (isa_and_nonnull<tosa::TransposeOp>(subop))
490 return rewriter.notifyMatchFailure(
491 op, "Dest is used by transpose, can compose transposes");
492 }
493
494 auto input = op.getInput1();
495 auto inputTy = llvm::cast<ShapedType>(input.getType());
496 if (!inputTy.hasRank())
497 return rewriter.notifyMatchFailure(op, "Unranked input.");
498
499 int64_t numDynDims = 0;
500 for (int i = 0; i < inputTy.getRank(); ++i)
501 if (inputTy.isDynamicDim(i))
502 numDynDims++;
503
504 if (numDynDims > 1)
505 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
506
507 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
508
509 SmallVector<int64_t> nonZeroPerms;
510 nonZeroPerms.reserve(permValues.size());
511 for (auto idx : permValues) {
512 auto sz = inputTy.getDimSize(idx);
513 if (sz != 1)
514 nonZeroPerms.push_back(idx);
515 }
516
517 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
518 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
519 return rewriter.notifyMatchFailure(op,
520 "Transpose changes memory layout.");
521
522 SmallVector<int64_t> newShape;
523 newShape.reserve(inputTy.getRank());
524 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
525 newShape.push_back(inputTy.getDimSize(permValues[i]));
526
527 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
528 op, op.getType(), op.getInput1(),
529 getTosaConstShape(rewriter, op.getLoc(), newShape));
530 return success();
531 }
532};
533
534void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
535 MLIRContext *context) {
536 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
537}
538
539struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
541
542 LogicalResult matchAndRewrite(tosa::ClampOp op,
543 PatternRewriter &rewriter) const override {
544 Value input = op.getInput();
545 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
546 auto inputElementType = inputType.getElementType();
547
548 if (isa<FloatType>(inputElementType)) {
549 // Unlike integer types, floating point types can represent infinity.
550 const auto minClamp =
551 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
552 const auto maxClamp =
553 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
554 const bool isMin = minClamp.isNegInfinity();
555 const bool isMax = maxClamp.isInfinity();
556
557 if (isMin && isMax) {
558 rewriter.replaceOp(op, input);
559 return success();
560 }
561 return failure();
562 }
563
564 // i1 types are boolean in TOSA
565 const bool isBoolean = inputElementType.isInteger(1);
566 if (inputElementType.isUnsignedInteger() || isBoolean) {
567 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
568 .getValue()
569 .getZExtValue();
570 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
571 .getValue()
572 .getZExtValue();
573
574 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
575 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
576 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
577
578 if (minClamp <= intMin && maxClamp >= intMax) {
579 rewriter.replaceOp(op, input);
580 return success();
581 }
582 return failure();
583 }
584
585 if (llvm::isa<IntegerType>(inputElementType)) {
586 const int64_t minClamp =
587 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
588 const int64_t maxClamp =
589 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
590
591 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
592 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
593 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
594
595 if (minClamp <= intMin && maxClamp >= intMax) {
596 rewriter.replaceOp(op, input);
597 return success();
598 }
599 return failure();
600 }
601
602 return failure();
603 }
604};
605
606// Attempts the following transformation:
607//
608// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
609// tensor X the following identity holds:
610//
611// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
612//
613// subject to the following valid NaN propagation semantics:
614// --------------------------------------------
615// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
616// |-------------|--------------|-------------|
617// | PROPAGATE | PROPAGATE | PROPAGATE |
618// | PROPAGATE | IGNORE | IGNORE |
619// | IGNORE | PROPAGATE | INVALID |
620// | IGNORE | IGNORE | IGNORE |
621// |------------------------------------------|
622
623struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
624 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
625
626 // Helper structure to describe the range of a clamp operation.
627 template <typename T>
628 struct ClampRange {
629 ClampRange(const T &start, const T &end) : start(start), end(end) {}
632
633 // Helper function to determine if two Clamp ranges intersect.
634 bool intersects(const ClampRange<T> &otherRange) {
635 return start < otherRange.end && otherRange.start < end;
636 }
637 };
638
639 LogicalResult matchAndRewrite(tosa::ClampOp op,
640 PatternRewriter &rewriter) const override {
641 Value input = op.getInput();
642
643 // Check the input to the CLAMP op is itself a CLAMP.
644 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
645 if (!clampOp)
646 return failure();
647
648 // Check we have a valid NaN propagation combination.
649 const auto opNanMode = op.getNanMode();
650 const auto clampNanMode = clampOp.getNanMode();
651 if (opNanMode == NanPropagationMode::IGNORE &&
652 clampNanMode == NanPropagationMode::PROPAGATE)
653 return failure();
654
655 auto maxValAttr = op.getMaxValAttr();
656 auto minValAttr = op.getMinValAttr();
657 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
658 auto clampOpMinValAttr = clampOp.getMinValAttr();
659
660 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
661 if (auto quantType =
662 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
663 inputEType = getStorageElementTypeFromQuantized(quantType);
664 }
665
666 Attribute newMinValAttr, newMaxValAttr;
667 if (mlir::isa<FloatType>(inputEType)) {
668 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
669 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
670 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
671 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
672
673 // Check we have intersecting ranges.
674 const auto opMinFloat = floatMinValAttr.getValue();
675 const auto opMaxFloat = floatMaxValAttr.getValue();
676 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
677 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
678 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
679 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
680 clampOpMaxFloat);
681 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
682 return failure();
683
684 // Run the transformation.
685 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
686 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
687 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
688 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
689 } else {
690 assert(mlir::isa<IntegerType>(inputEType));
691 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
692 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
693 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
694 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
695
696 if (inputEType.isUnsignedInteger()) {
697 // Check we have intersecting ranges.
698 const auto opMinInt = intMinValAttr.getUInt();
699 const auto opMaxInt = intMaxValAttr.getUInt();
700 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
701 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
702 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
703 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
704 clampOpMaxInt);
705 if (!opRangeIntRange.intersects(clampRangeIntRange))
706 return failure();
707
708 // Run the transformation.
709 auto newMinVal = std::max(opMinInt, clampOpMinInt);
710 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
711 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
712 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
713 } else {
714 // Check we have intersecting ranges.
715 const auto opMinInt = intMinValAttr.getInt();
716 const auto opMaxInt = intMaxValAttr.getInt();
717 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
718 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
719 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
720 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
721 clampOpMaxInt);
722 if (!opRangeIntRange.intersects(clampRangeIntRange))
723 return failure();
724
725 // Run the transformation.
726 auto newMinVal = std::max(opMinInt, clampOpMinInt);
727 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
728 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
729 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
730 }
731 }
732
733 auto newMode = (opNanMode != clampNanMode)
734 ? tosa::NanPropagationMode::IGNORE
735 : opNanMode;
736
737 auto newModeAttr =
738 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
739
740 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
741 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
742 newModeAttr);
743 return success();
744 }
745};
746
747void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
748 MLIRContext *context) {
749 results.add<ClampIsNoOp>(context);
750 results.add<ClampClampOptimization>(context);
751}
752
753struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
754 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
755
756 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
757 PatternRewriter &rewriter) const override {
758 Value sliceInput = sliceOp.getInput1();
759 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
760 if (!concatOp)
761 return rewriter.notifyMatchFailure(
762 sliceOp, "slice input must be concat operation");
763
764 OperandRange inputs = concatOp.getInput1();
765 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
766 if (!concatType || !concatType.hasStaticShape())
767 return rewriter.notifyMatchFailure(
768 sliceOp, "slice input must be a static ranked tensor");
769 int32_t axis = concatOp.getAxis();
770
771 DenseElementsAttr startElems;
772 DenseElementsAttr sizeElems;
773
774 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
775 return rewriter.notifyMatchFailure(
776 sliceOp, "start of slice must be a static ranked shape");
777
778 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
779 return rewriter.notifyMatchFailure(
780 sliceOp, "size of slice must be a static ranked shape");
781
782 llvm::SmallVector<int64_t> sliceStarts =
783 llvm::to_vector(startElems.getValues<int64_t>());
784 llvm::SmallVector<int64_t> sliceSizes =
785 llvm::to_vector(sizeElems.getValues<int64_t>());
786
787 // Validate slice on the concatenated axis. Slicing along this
788 // axis should span only one of the inputs to the concatenate
789 // operation.
790 std::optional<Value> replaceWithSlice;
791 for (auto input : inputs) {
792 auto inputType = dyn_cast<RankedTensorType>(input.getType());
793 if (!inputType || !inputType.hasStaticShape())
794 return rewriter.notifyMatchFailure(
795 sliceOp, "concat input must be a static ranked tensor");
796
797 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
798 inputType.getDimSize(axis)) {
799 auto start_op =
800 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
801 auto size_op =
802 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
803 replaceWithSlice =
804 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
805 input, start_op, size_op)
806 .getResult();
807 break;
808 }
809 sliceStarts[axis] -= inputType.getDimSize(axis);
810 }
811
812 if (!replaceWithSlice)
813 return rewriter.notifyMatchFailure(
814 sliceOp, "corresponding concat input not found for slice");
815
816 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
817 return success();
818 }
819};
820
821struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
822 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
823
824 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
825 PatternRewriter &rewriter) const override {
826 Value sliceInput = sliceOp.getInput1();
827
828 // Check if producer is a PadOp
829 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
830 if (!padOp)
831 return rewriter.notifyMatchFailure(sliceOp,
832 "slice input must be a pad operation");
833
834 // Check PadOp has a single consumer
835 if (!padOp->hasOneUse())
836 return rewriter.notifyMatchFailure(sliceOp,
837 "pad shall have a single consumer");
838
839 // Check input is statically ranked
840 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
841 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
842 if (!inputTy || !padTy || !inputTy.hasRank())
843 return rewriter.notifyMatchFailure(sliceOp,
844 "slice input must be a ranked tensor");
845
846 // Validate and extract tosa::PadOp padding
847 DenseIntElementsAttr paddingElems;
848 if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
849 return rewriter.notifyMatchFailure(
850 sliceOp,
851 "`padding` input specified on the tosa::PadOp must be constant.");
852 }
853 llvm::SmallVector<int64_t> padPaddings =
854 llvm::to_vector(paddingElems.getValues<int64_t>());
855
856 // Extract slice parameters
857 DenseElementsAttr startElems;
858 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
859 return rewriter.notifyMatchFailure(
860 sliceOp, "start of slice must be a static ranked shape");
861 llvm::SmallVector<int64_t> sliceStarts =
862 llvm::to_vector(startElems.getValues<int64_t>());
863
864 DenseElementsAttr sizeElems;
865 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
866 return rewriter.notifyMatchFailure(
867 sliceOp, "size of slice must be a static ranked shape");
868 llvm::SmallVector<int64_t> sliceSizes =
869 llvm::to_vector(sizeElems.getValues<int64_t>());
870
871 // Check if dynamic dimensions are sliced
872 const int64_t rank = inputTy.getRank();
873 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
874 const bool isDimDynamic = inputTy.isDynamicDim(i);
875 const bool isDimSliced =
876 (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);
877
878 return isDimDynamic && isDimSliced;
879 })) {
880 return rewriter.notifyMatchFailure(
881 sliceOp, "axis that are sliced shall be statically known.");
882 }
883
884 // Update the parameters
885 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
886 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
887 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
888 bool updated = false;
889
890 for (int64_t i = 0; i < rank; ++i) {
891 const int64_t padLo = padPaddings[i * 2];
892 const int64_t padHi = padPaddings[i * 2 + 1];
893 const int64_t sliceStart = sliceStarts[i];
894 const int64_t sliceSize = sliceSizes[i];
895 const int64_t sliceEnd = sliceStart + sliceSize;
896
897 // If dimension is dynamic pass-through
898 if (inputTy.isDynamicDim(i)) {
899 newPadPaddings[i * 2] = padLo;
900 newPadPaddings[i * 2 + 1] = padHi;
901 newSliceStarts[i] = sliceStart;
902 continue;
903 }
904
905 // Handle static dimensions
906 const int64_t dimSize = inputTy.getShape()[i];
907 const int64_t dimTotal = padLo + dimSize + padHi;
908
909 // Check slice within bounds
910 if (sliceStart < 0 || sliceEnd > dimTotal)
911 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
912
913 // Compute updated slice start parameter
914 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
915 newSliceStarts[i] = newSliceStart;
916 updated |= newSliceStart != sliceStart;
917
918 // Compute updated pad parameters
919 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
920 const int64_t newPadHi =
921 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
922 newPadPaddings[i * 2] = newPadLo;
923 newPadPaddings[i * 2 + 1] = newPadHi;
924 updated |= (newPadLo != padLo) || (newPadHi != padHi);
925
926 // Calculate new pad output shape
927 newPadShape[i] =
928 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
929 }
930
931 // Check that we actually need to proceed with the rewrite
932 if (!updated)
933 return rewriter.notifyMatchFailure(
934 sliceOp, "terminate condition; nothing to rewrite");
935
936 // Create a PadOp with updated padding
937 auto newPaddingsOp =
938 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
939 auto newPadTy =
940 RankedTensorType::get(newPadShape, inputTy.getElementType());
941 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
942 padOp.getInput1(), newPaddingsOp,
943 padOp.getPadConst());
944
945 // Update SliceOp and point to new PadOp
946 auto newStartOp =
947 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
948 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
949 newPadOp.getResult(), newStartOp,
950 sliceOp.getSize());
951
952 return success();
953 }
954};
955
956// Update size operand of tosa.slice if size has dynamic dims but corresponding
957// output dim is static
959 : public OpRewritePattern<tosa::SliceOp> {
960 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
961
962 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
963 PatternRewriter &rewriter) const override {
964 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
965 if (!resultType.hasRank())
966 return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");
967
968 ElementsAttr sizeElems;
969 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
970 return rewriter.notifyMatchFailure(
971 sliceOp, "size of slice must be a static ranked shape");
972 }
973
974 llvm::SmallVector<int64_t> sliceSizes =
975 llvm::to_vector(sizeElems.getValues<int64_t>());
976
977 bool replaceSliceSize{false};
978 // if size op has kInferableDimSize indicating dynamic shape but
979 // corresponding dim on the output is statically known, update size to match
980 // with known output dim shape
981 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
982 if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
983 sliceSizes[index] = resultType.getDimSize(index);
984 replaceSliceSize = true;
985 }
986 }
987
988 if (!replaceSliceSize) {
989 return rewriter.notifyMatchFailure(
990 sliceOp, "no dimension of size of slice is dynamic that resolves "
991 "to static output shape");
992 }
993
994 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
995 auto newSliceOp =
996 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
997 sliceOp.getInput1(), sliceOp.getStart(), size_op);
998
999 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
1000 return success();
1001 }
1002};
1003
1004void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1005 MLIRContext *context) {
1006 results.add<ConcatSliceOptimization, PadSliceOptimization,
1007 SliceDynamicSizeCanonicalization>(context);
1008}
1009
1011 using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
1012
1013 LogicalResult matchAndRewrite(tosa::CastOp castOp,
1014 PatternRewriter &rewriter) const override {
1015 const Value castInput = castOp.getInput();
1016 auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
1017 if (!innerCastOp)
1018 return rewriter.notifyMatchFailure(castOp,
1019 "input must be cast operation");
1020
1021 const Value innerCastInput = innerCastOp.getInput();
1022
1023 const ShapedType innerInputType =
1024 llvm::cast<ShapedType>(innerCastInput.getType());
1025 const ShapedType innerOutputType =
1026 llvm::cast<ShapedType>(innerCastOp.getType());
1027 const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
1028
1029 const Type innerInputElemType = innerInputType.getElementType();
1030 const Type innerOutputElemType = innerOutputType.getElementType();
1031 const Type outerOutputElemType = outerOutputType.getElementType();
1032
1033 const SmallVector<Type, 3> types = {innerInputElemType, innerOutputElemType,
1034 outerOutputElemType};
1035
1036 if (llvm::any_of(types, [](const Type type) {
1037 // Support a specific set of floating point types since we need to be
1038 // careful in not introducing unsupported type combinations
1039 return !(type.isInteger() ||
1040 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1041 Float16Type, Float32Type>(type));
1042 }))
1043 return rewriter.notifyMatchFailure(
1044 castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
1045 "supported");
1046
1047 if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
1048 llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
1049 return rewriter.notifyMatchFailure(
1050 castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
1051 "legal in TOSA");
1052 }
1053
1054 if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
1055 llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
1056 return rewriter.notifyMatchFailure(
1057 castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
1058 "legal in TOSA");
1059 }
1060
1061 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
1062 outerOutputElemType.isInteger()) {
1063 return rewriter.notifyMatchFailure(
1064 castOp, "avoid introducing fp8 -> integer casts which are not "
1065 "legal in TOSA");
1066 }
1067
1068 if (innerInputElemType.isInteger() &&
1069 llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
1070 return rewriter.notifyMatchFailure(
1071 castOp, "avoid introducing integer -> fp8 casts which are not "
1072 "legal in TOSA");
1073 }
1074
1075 if (llvm::isa<Float16Type>(innerInputElemType) &&
1076 llvm::isa<BFloat16Type>(outerOutputElemType)) {
1077 return rewriter.notifyMatchFailure(
1078 castOp, "avoid introducing fp16 -> bf16 casts which are not "
1079 "legal in TOSA");
1080 }
1081
1082 if (llvm::isa<BFloat16Type>(innerInputElemType) &&
1083 llvm::isa<Float16Type>(outerOutputElemType)) {
1084 return rewriter.notifyMatchFailure(
1085 castOp, "avoid introducing bf16 -> fp16 casts which are not "
1086 "legal in TOSA");
1087 }
1088
1089 const auto isIntegerOneOfWidth = [](Type type, size_t bitwidth1,
1090 size_t bitwidth2) {
1091 return type.isInteger(bitwidth1) || type.isInteger(bitwidth2);
1092 };
1093
1094 if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
1095 outerOutputElemType.isInteger(64)) {
1096 return rewriter.notifyMatchFailure(
1097 castOp, "avoid introducing i8/i16 -> i64 casts which are not "
1098 "legal in TOSA");
1099 }
1100
1101 if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
1102 !outerOutputElemType.isInteger()) {
1103 return rewriter.notifyMatchFailure(
1104 castOp, "avoid introducing bool/i64 to float casts which are not "
1105 "supported in all versions of TOSA");
1106 }
1107
1108 if (!innerInputElemType.isInteger() &&
1109 isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
1110 return rewriter.notifyMatchFailure(
1111 castOp, "avoid introducing float to bool/i64 casts which are not "
1112 "supported in all versions of TOSA");
1113 }
1114
1115 // Check that the cast we're considering for removal is non-narrowing
1116 if (isNarrowingCast(innerInputType, innerOutputType))
1117 return rewriter.notifyMatchFailure(castOp,
1118 "inner cast operation is narrowing");
1119
1120 rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
1121 innerCastInput);
1122
1123 return success();
1124 }
1125
1126 bool supportsNaN(const llvm::fltSemantics &semantics) const {
1127 return semantics.nonFiniteBehavior !=
1128 llvm::fltNonfiniteBehavior::FiniteOnly;
1129 }
1130
1131 bool supportsInf(const llvm::fltSemantics &semantics) const {
1132 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
1133 }
1134
1135 bool isNarrowingCast(const ShapedType inType,
1136 const ShapedType outType) const {
1137
1138 if (inType.getElementType().isInteger() &&
1139 outType.getElementType().isInteger()) {
1140
1141 const auto inTypeSignedness =
1142 cast<IntegerType>(inType.getElementType()).getSignedness();
1143 const auto outTypeSignedness =
1144 cast<IntegerType>(outType.getElementType()).getSignedness();
1145
1146 return (inTypeSignedness != outTypeSignedness ||
1147 inType.getElementTypeBitWidth() >
1148 outType.getElementTypeBitWidth());
1149 }
1150
1151 if (inType.getElementType().isFloat() &&
1152 outType.getElementType().isFloat()) {
1153
1154 FloatType inElemTy = cast<FloatType>(inType.getElementType());
1155 FloatType outElemTy = cast<FloatType>(outType.getElementType());
1156 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
1157 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
1158
1159 // If the list of supported types needs to be updated in the future, the
1160 // check down below will need to be revised, for example to account for
1161 // unsigned floating point types, or types that use negative zero as the
1162 // representation for NaN.
1163 [[maybe_unused]] const auto isSupported = [](Type elemType) {
1164 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1165 Float16Type, Float32Type>(elemType);
1166 };
1167
1168 assert(isSupported(inElemTy) &&
1169 "unsupported input element type in isNarrowingCast");
1170 assert(isSupported(outElemTy) &&
1171 "unsupported output element type in isNarrowingCast");
1172
1173 return (
1174 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1175 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1176 inTypeSemantics.precision > outTypeSemantics.precision ||
1177 (supportsNaN(inTypeSemantics) && !supportsNaN(outTypeSemantics)) ||
1178 (supportsInf(inTypeSemantics) && !supportsInf(outTypeSemantics)));
1179 }
1180
1181 // While some cases of int -> float casts can be non-narrowing, consider
1182 // them narrowing for the purposes of this optimization
1183 return true;
1184 }
1185};
1186
// Registers the canonicalization patterns rooted at tosa.cast; the only one
// added here is NonNarrowingCastsOptimization.
1187void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1188 MLIRContext *context) {
1189 results.add<NonNarrowingCastsOptimization>(context);
1190}
1191
1193 : public OpRewritePattern<tosa::CastToBlockScaledOp> {
1194 using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
1195
1196 LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
1197 PatternRewriter &rewriter) const override {
1198 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1199 auto castFromBlockScaledOp =
1200 castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
1201 if (!castFromBlockScaledOp)
1202 return rewriter.notifyMatchFailure(
1203 castToBlockScaledOp,
1204 "input must be cast_from_block_scaled operation");
1205
1206 const Value innerData = castFromBlockScaledOp.getInputData();
1207 const Value innerScale = castFromBlockScaledOp.getInputScale();
1208 const auto innerDataTy = llvm::cast<ShapedType>(innerData.getType());
1209 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.getType());
1210
1211 const Value outerData = castToBlockScaledOp.getOutputData();
1212 const Value outerScale = castToBlockScaledOp.getOutputScale();
1213 const auto outerDataTy = llvm::cast<ShapedType>(outerData.getType());
1214 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.getType());
1215
1216 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1217 return rewriter.notifyMatchFailure(
1218 castToBlockScaledOp,
1219 "inputs types to cast_from_block_scaled operation must match output "
1220 "types to cast_to_block_scaled");
1221 }
1222
1223 if (castFromBlockScaledOp.getBlockSize() !=
1224 castToBlockScaledOp.getBlockSize()) {
1225 return rewriter.notifyMatchFailure(
1226 castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
1227 "cast_to_block_scaled must match");
1228 }
1229
1230 rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});
1231
1232 return success();
1233 }
1234};
1235
// Registers the canonicalization patterns rooted at tosa.cast_to_block_scaled;
// the only one added here is CancellingBlockScaledCastsOptimization.
1236void CastToBlockScaledOp::getCanonicalizationPatterns(
1237 RewritePatternSet &results, MLIRContext *context) {
1238 results.add<CancellingBlockScaledCastsOptimization>(context);
1239}
1240
// Rewrites a tosa.row_gather into a tosa.gather when its row count is exactly
// one, reusing the same values, indices and output type.
// NOTE(review): the expression initializing `rowCount` is not visible in this
// listing — confirm how the row count is obtained before relying on this note.
1241struct RowGatherToGather : public OpRewritePattern<tosa::RowGatherOp> {
1242 using OpRewritePattern<tosa::RowGatherOp>::OpRewritePattern;
1243
1244 LogicalResult matchAndRewrite(tosa::RowGatherOp op,
1245 PatternRewriter &rewriter) const override {
1246 const FailureOr<int32_t> rowCount =
// Only a known row count of one is rewritten; anything else does not match.
1248 if (failed(rowCount) || rowCount.value() != 1)
1249 return failure();
1250
1251 rewriter.replaceOpWithNewOp<tosa::GatherOp>(
1252 op, op.getOutput().getType(), op.getValues(), op.getIndices());
1253 return success();
1254 }
1255};
1256
// Registers the canonicalization patterns rooted at tosa.row_gather; the only
// one added here is RowGatherToGather.
1257void RowGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
1258 MLIRContext *context) {
1259 results.add<RowGatherToGather>(context);
1260}
1261
1262//===----------------------------------------------------------------------===//
1263// Operator Folders.
1264//===----------------------------------------------------------------------===//
1265
1266template <typename Folder>
1267static DenseElementsAttr
1269 bool foldDenseValues = false) {
1270 if (!lhs || !rhs)
1271 return {};
1272
1273 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1274 return {};
1275
1276 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
1277 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
1278 if (lETy != rETy)
1279 return {};
1280
1281 if (lhs.isSplat() && rhs.isSplat()) {
1282 if (isa<FloatType>(lETy)) {
1283 const APFloat l = lhs.getSplatValue<APFloat>();
1284 const APFloat r = rhs.getSplatValue<APFloat>();
1285 const auto maybeResult = Folder::fold(l, r);
1286 if (failed(maybeResult))
1287 return {};
1288 return DenseElementsAttr::get(returnTy, maybeResult.value());
1289 }
1290
1291 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1292 const APInt l = lhs.getSplatValue<APInt>();
1293 const APInt r = rhs.getSplatValue<APInt>();
1294 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1295 if (failed(maybeResult))
1296 return {};
1297 return DenseElementsAttr::get(returnTy, maybeResult.value());
1298 }
1299 }
1300
1301 if (foldDenseValues) {
1302 assert(lETy.isIntOrIndex() &&
1303 "Only integer types are currently supported.");
1304 SmallVector<APInt> resultValues;
1305 for (auto [l, r] :
1306 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
1307 const auto maybeResult = Folder::fold(l, r, false);
1308 if (failed(maybeResult))
1309 return {};
1310 resultValues.push_back(maybeResult.value());
1311 }
1312 return DenseElementsAttr::get(returnTy, resultValues);
1313 }
1314
1315 return {};
1316}
1317
1318template <typename Folder>
1319static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
1320 bool foldDenseValues = false) {
1321 if (!val)
1322 return {};
1323
1324 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1325 return {};
1326
1327 const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
1328
1329 if (val.isSplat()) {
1330 if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1331 const APInt v = val.getSplatValue<APInt>();
1332 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1333 if (failed(maybeResult))
1334 return {};
1335 return DenseElementsAttr::get(returnTy, maybeResult.value());
1336 }
1337 }
1338
1339 if (foldDenseValues) {
1340 mlir::Type elemTy = val.getElementType();
1341 if (elemTy.isIntOrIndex()) {
1342 SmallVector<APInt> resultValues;
1343 for (auto const &v : val.getValues<APInt>()) {
1344 const auto maybeResult = Folder::fold(v, false);
1345 if (failed(maybeResult))
1346 return {};
1347 resultValues.push_back(maybeResult.value());
1348 }
1349 return DenseElementsAttr::get(returnTy, resultValues);
1350 }
1351 }
1352
1353 // Folding arbitrarily sized tensor operations is not supported
1354 return {};
1355}
1356
1357static FailureOr<int64_t> getSingleI64From1ElementTensor(Value v) {
1358 DenseIntElementsAttr dense{};
1359 if (!matchPattern(v, m_Constant(&dense)))
1360 return failure();
1361
1362 assert(dense.isSplat());
1363 APInt a = dense.getSplatValue<APInt>();
1364 return a.getSExtValue();
1365}
1366
1368 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1369 const bool isUnsigned) {
1370 bool overflow;
1371 const APInt result =
1372 isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
1373 if (overflow)
1374 return failure();
1375 return result;
1376 }
1377
1378 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1379 return lhs + rhs;
1380 }
1381};
1382
1384 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1385 const bool isUnsigned) {
1386 bool overflow;
1387 const APInt result =
1388 isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
1389 if (overflow)
1390 return failure();
1391 return result;
1392 }
1393
1394 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1395 return lhs - rhs;
1396 }
1397};
1398
1400 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1401 const bool isUnsigned) {
1402
1403 const unsigned originalWidth = lhs.getBitWidth();
1404
1405 // Check same type
1406 if (lhs.getBitWidth() != rhs.getBitWidth()) {
1407 return failure();
1408 }
1409
1410 // If either is `0`
1411 if (lhs == 0 || rhs == 0)
1412 return APInt::getZero(originalWidth);
1413
1414 bool overflow = false;
1415 APInt const result =
1416 isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);
1417
1418 if (overflow)
1419 return failure();
1420
1421 return result.trunc(originalWidth);
1422 }
1423
1424 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1425 return lhs * rhs;
1426 }
1427};
1428
1429static bool signsDiffer(const APInt &a, const APInt &b) {
1430 return a.isNegative() != b.isNegative();
1431}
1432
1433template <bool Ceil>
1435 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1436 bool isUnsigned) {
1437 if (lhs.getBitWidth() != rhs.getBitWidth())
1438 return failure();
1439 if (rhs.isZero())
1440 return failure();
1441
1442 if (isUnsigned) {
1443 APInt q{};
1444 APInt r{};
1445 APInt::udivrem(lhs, rhs, q, r);
1446 if (!r.isZero() && Ceil) {
1447 return q + 1;
1448 }
1449 return q;
1450 }
1451
1452 // Signed: start from trunc-toward-zero, then adjust to ceil.
1453 bool overflow{false};
1454 APInt const q = lhs.sdiv_ov(rhs, overflow);
1455 if (overflow)
1456 return failure();
1457 APInt const r = lhs.srem(rhs);
1458
1459 if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) {
1460 // Same sign => exact quotient is positive; trunc is below ceil =>
1461 // increment q.
1462 return q + 1;
1463 }
1464 return q;
1465 }
1466
1467 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1468 return lhs / rhs;
1469 }
1470};
1471
1473 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1474 bool isUnsigned) {
1475 if (lhs.getBitWidth() != rhs.getBitWidth())
1476 return failure();
1477 if (lhs.isNegative() || (!rhs.isStrictlyPositive()))
1478 return failure();
1479
1480 if (isUnsigned) {
1481 return lhs.urem(rhs);
1482 }
1483
1484 return lhs.srem(rhs);
1485 }
1486
1487 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1488 auto t = lhs;
1489 auto const r = t.mod(rhs);
1490 if (llvm::APFloatBase::opStatus::opOK == r) {
1491 return t;
1492 }
1493 return failure();
1494 }
1495};
1496
1498 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1499 bool isUnsigned) {
1500 if (lhs.getBitWidth() != rhs.getBitWidth())
1501 return failure();
1502 return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;
1503 }
1504
1505 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1506 return lhs >= rhs ? lhs : rhs;
1507 }
1508};
1509
1511 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1512 bool isUnsigned) {
1513 if (lhs.getBitWidth() != rhs.getBitWidth())
1514 return failure();
1515 return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;
1516 }
1517
1518 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1519 return lhs <= rhs ? lhs : rhs;
1520 }
1521};
1522
1524 static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
1525 auto const numBits = value.getBitWidth();
1526 if (isUnsigned) {
1527 auto const zextv = value.getZExtValue();
1528 if (zextv >= numBits)
1529 return failure();
1530 return APInt::getOneBitSet(numBits, zextv);
1531 }
1532 auto const sextv = value.getSExtValue();
1533 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1534 return failure();
1535 return APInt::getOneBitSet(numBits, sextv);
1536 }
1537};
1538
1539// The specification requires shape div operations to have non-negative lhs and
1540// strictly positive rhs so we can only fold when these conditions are met.
1541template <bool Ceil>
1543 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1544 bool isUnsigned) {
1545 assert(!isUnsigned &&
1546 "unsigned values are not supported for shape div folders");
1547 if (lhs.isNegative() || !rhs.isStrictlyPositive())
1548 return failure();
1549 return DivFoldAdaptor<Ceil>::fold(lhs, rhs, isUnsigned);
1550 }
1551
1552 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1553 return failure();
1554 }
1555};
1556
1558 static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
1559 if (!value.isStrictlyPositive())
1560 return failure();
1561 return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
1562 }
1563};
1564
1566 static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
1567 if (!value.isStrictlyPositive())
1568 return failure();
1569 return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
1570 }
1571};
1572
1574 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1575 const bool isUnsigned) {
1576 return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
1577 }
1578
1579 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1580 return APInt(1, lhs > rhs);
1581 }
1582};
1583
1585 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1586 const bool isUnsigned) {
1587 return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
1588 }
1589
1590 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1591 return APInt(1, lhs >= rhs);
1592 }
1593};
1594
1596 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1597 const bool isUnsigned) {
1598 return APInt(1, lhs == rhs);
1599 }
1600
1601 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1602 return APInt(1, lhs == rhs);
1603 }
1604};
1605
1606static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1607 if (llvm::isa<FloatType>(elemType))
1608 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1609 if (llvm::isa<IntegerType>(elemType))
1610 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1611 return false;
1612}
1613
1614static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1615 if (llvm::isa<FloatType>(elemType))
1616 return val && val.isSplat() &&
1617 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1618 if (llvm::isa<IntegerType>(elemType)) {
1619 const int64_t shifted = 1LL << shift;
1620 return val && val.isSplat() &&
1621 val.getSplatValue<APInt>().getSExtValue() == shifted;
1622 }
1623 return false;
1624}
1625
1626OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1627 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1628 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1629 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1630 if (!lhsTy || !rhsTy || !resultTy)
1631 return {};
1632
1633 // Cannot create an ElementsAttr from non-int/float/index types
1634 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1635 !rhsTy.getElementType().isIntOrIndexOrFloat())
1636 return {};
1637
1638 auto resultETy = resultTy.getElementType();
1639 auto lhsAttr =
1640 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1641 auto rhsAttr =
1642 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1643
1644 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1645 lhsTy.getShape(), rhsTy.getShape());
1646 if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1647 return getInput1();
1648 if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
1649 return getInput2();
1650
1651 if (!lhsAttr || !rhsAttr)
1652 return {};
1653
1654 return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1655}
1656
1657OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1658 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1659 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1660 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1661 !outputTy.hasStaticShape())
1662 return {};
1663
1664 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1665 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1666 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1667 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1668 return DenseElementsAttr::get(outputTy, zero);
1669 }
1670
1671 return {};
1672}
1673
1674OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1675 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1676 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1677 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1678 if (!lhsTy || !rhsTy || !resultTy)
1679 return {};
1680 if (lhsTy.getElementType() != rhsTy.getElementType())
1681 return {};
1682
1683 // IntDivOp inputs must be integer type, no need to check for quantized
1684 // type
1685 auto resultETy = resultTy.getElementType();
1686 auto lhsAttr =
1687 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1688 auto rhsAttr =
1689 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1690 if (lhsAttr && lhsAttr.isSplat() && rhsAttr && rhsAttr.isSplat()) {
1691 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1692 lhsAttr.getSplatValue<APInt>().isZero() &&
1693 !rhsAttr.getSplatValue<APInt>().isZero()) {
1694 return lhsAttr.resizeSplat(resultTy);
1695 }
1696 }
1697
1698 if (rhsAttr && rhsAttr.isSplat()) {
1699 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1700 lhsTy.getShape(), rhsTy.getShape());
1701 if (isBroadcastable && lhsTy == resultTy &&
1702 llvm::isa<IntegerType>(resultETy) &&
1703 rhsAttr.getSplatValue<APInt>().isOne())
1704 return getInput1();
1705 }
1706
1707 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1708 llvm::isa<IntegerType>(resultETy)) {
1709 APInt l = lhsAttr.getSplatValue<APInt>();
1710 APInt r = rhsAttr.getSplatValue<APInt>();
1711 if (!r.isZero()) {
1712 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1713 auto const result =
1714 DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
1715 if (failed(result))
1716 return {};
1717 return DenseElementsAttr::get(resultTy, result.value());
1718 }
1719 }
1720
1721 return {};
1722}
1723
1724namespace {
1725// calculate lhs * rhs >> shift according to TOSA Spec
1726// return nullopt if result is not in range of int32_t when shift > 0
1727std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1728 unsigned bitwidth) {
1729 bool overflow = false;
1730 APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);
1731
1732 if (overflow)
1733 return std::nullopt;
1734
1735 if (shift > 0) {
1736 auto round = APInt(64, 1) << (shift - 1);
1737 result += round;
1738 result.ashrInPlace(shift);
1739 // REQUIRE(product >= minimum_s<i32_t>() && product <=
1740 // maximum_s<i32_t>())
1741 if (!(result.getSExtValue() >= INT32_MIN &&
1742 result.getSExtValue() <= INT32_MAX)) {
1743 // REQUIRE failed
1744 return std::nullopt;
1745 }
1746 }
1747
1748 return result.trunc(bitwidth);
1749}
1750
1751DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1752 RankedTensorType ty, int32_t shift) {
1753 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1754 if (llvm::isa<IntegerType>(ty.getElementType())) {
1755 APInt l = lhs.getSplatValue<APInt>();
1756 APInt r = rhs.getSplatValue<APInt>();
1757
1758 if (shift == 0) {
1759 return DenseElementsAttr::get(ty, l * r);
1760 }
1761
1762 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1763 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1764 if (!result)
1765 return {};
1766 return DenseElementsAttr::get(ty, result.value());
1767 }
1768
1769 if (llvm::isa<FloatType>(ty.getElementType())) {
1770 APFloat l = lhs.getSplatValue<APFloat>();
1771 APFloat r = rhs.getSplatValue<APFloat>();
1772 APFloat result = l * r;
1773 return DenseElementsAttr::get(ty, result);
1774 }
1775 }
1776
1777 return {};
1778}
1779} // namespace
1780
1781OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1782 auto lhs = getInput1();
1783 auto rhs = getInput2();
1784 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1785 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1786 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1787 if (!lhsTy || !rhsTy || !resultTy)
1788 return {};
1789
1790 auto resultETy = resultTy.getElementType();
1791 auto lhsAttr =
1792 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1793 auto rhsAttr =
1794 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1795
1796 // Result right shift on i32_t data type only. For simplification,
1797 // synthesize a zero shift for other data type.
1798 int32_t shift = 0;
1799 if (resultETy.isInteger(32)) {
1800 ElementsAttr shift_elem;
1801 if (getShift().getImpl()) {
1802 if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1803 // cannot be folded when the shift value is unknown.
1804 return {};
1805 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1806 }
1807 }
1808
1809 if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
1810 resultTy.hasStaticShape())
1811 // constant values can only be resized if resulting type is static
1812 return lhsAttr.resizeSplat(resultTy);
1813 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
1814 resultTy.hasStaticShape())
1815 return rhsAttr.resizeSplat(resultTy);
1816
1817 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1818 lhsTy.getShape(), rhsTy.getShape());
1819 if (isBroadcastable && rhsTy == resultTy &&
1820 isSplatOne(resultETy, lhsAttr, shift))
1821 return rhs;
1822 if (isBroadcastable && lhsTy == resultTy &&
1823 isSplatOne(resultETy, rhsAttr, shift))
1824 return lhs;
1825
1826 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1827}
1828
1829OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1830 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1831 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1832 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1833 if (!lhsTy || !rhsTy || !resultTy)
1834 return {};
1835
1836 // Cannot create an ElementsAttr from non-int/float/index types
1837 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1838 !rhsTy.getElementType().isIntOrIndexOrFloat())
1839 return {};
1840
1841 auto resultETy = resultTy.getElementType();
1842 auto lhsAttr =
1843 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1844 auto rhsAttr =
1845 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1846
1847 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1848 lhsTy.getShape(), rhsTy.getShape());
1849 if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1850 return getInput1();
1851
1852 if (!lhsAttr || !rhsAttr)
1853 return {};
1854
1855 return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1856}
1857
1858OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1859 auto resultTy = llvm::cast<ShapedType>(getType());
1860 auto lhsAttr =
1861 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1862 auto rhsAttr =
1863 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1864
1865 if (!lhsAttr || !rhsAttr)
1866 return {};
1867
1868 return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1869}
1870
1871OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1872 auto resultTy = llvm::cast<ShapedType>(getType());
1873 auto lhsAttr =
1874 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1875 auto rhsAttr =
1876 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1877
1878 if (!lhsAttr || !rhsAttr)
1879 return {};
1880
1881 return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1882}
1883
1884OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1885 auto resultTy = llvm::cast<ShapedType>(getType());
1886 auto lhsAttr =
1887 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1888 auto rhsAttr =
1889 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1890 Value lhs = getInput1();
1891 Value rhs = getInput2();
1892 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1893
1894 // If we are comparing an integer value to itself it is always true. We
1895 // can not do this with float due to float values.
1896 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1897 resultTy.hasStaticShape() && lhs == rhs) {
1898 return DenseElementsAttr::get(resultTy, true);
1899 }
1900
1901 if (!lhsAttr || !rhsAttr)
1902 return {};
1903
1904 return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1905}
1906
1907OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1908 if (getInput().getType() == getType())
1909 return getInput();
1910
1911 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1912 if (!operand)
1913 return {};
1914
1915 auto inTy = llvm::cast<ShapedType>(getInput().getType());
1916 auto outTy = llvm::cast<ShapedType>(getType());
1917 if (!outTy.hasRank() || !outTy.hasStaticShape())
1918 return {};
1919 auto inETy = inTy.getElementType();
1920 auto outETy = outTy.getElementType();
1921
1922 if (operand.isSplat()) {
1923 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1924 bool overflow;
1925 auto splatVal = operand.getSplatValue<APFloat>();
1926 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1927 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1928 &overflow);
1929 return SplatElementsAttr::get(outTy, splatVal);
1930 }
1931
1932 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1933 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1934 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1935 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1936 llvm::RoundingMode::NearestTiesToEven);
1937 return SplatElementsAttr::get(outTy, splatVal);
1938 }
1939
1940 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1941 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1942 auto intVal = APSInt(
1943 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1944 auto floatVal = operand.getSplatValue<APFloat>();
1945 bool exact;
1946 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1947 &exact);
1948 return SplatElementsAttr::get(outTy, intVal);
1949 }
1950
1951 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1952 const auto inIntType = llvm::cast<IntegerType>(inETy);
1953 auto unsignIn = inIntType.isUnsignedInteger();
1954 bool trunc =
1955 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1956 auto intVal = operand.getSplatValue<APInt>();
1957 auto bitwidth = outETy.getIntOrFloatBitWidth();
1958
1959 // i1 types are boolean in TOSA
1960 if (outETy.isInteger(1)) {
1961 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1962 } else if (trunc) {
1963 intVal = intVal.trunc(bitwidth);
1964 } else if (unsignIn || inIntType.isInteger(1)) {
1965 intVal = intVal.zext(bitwidth);
1966 } else {
1967 intVal = intVal.sext(bitwidth);
1968 }
1969
1970 return SplatElementsAttr::get(outTy, intVal);
1971 }
1972 }
1973
1974 return {};
1975}
1976
1977OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1978
1979OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1980
1981#define REDUCE_FOLDER(OP) \
1982 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1983 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1984 if (!inputTy.hasRank()) \
1985 return {}; \
1986 if (inputTy != getType()) \
1987 return {}; \
1988 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1989 return getInput(); \
1990 return {}; \
1991 }
1992
1993REDUCE_FOLDER(ReduceAllOp)
1994REDUCE_FOLDER(ReduceAnyOp)
1995REDUCE_FOLDER(ReduceMaxOp)
1996REDUCE_FOLDER(ReduceMinOp)
1997REDUCE_FOLDER(ReduceProductOp)
1998REDUCE_FOLDER(ReduceSumOp)
1999#undef REDUCE_FOLDER
2000
2001OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
2002 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2003 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
2004
2005 if (!inputTy || !outputTy)
2006 return {};
2007
2008 // Fold when the input and output types are the same. This is only safe
2009 // when there is at most 1 dynamic dimension. For 2 or more dynamic
2010 // dimensions, there may still be a productive reshape.
2011 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
2012 return getInput1();
2013
2014 // reshape(reshape(x)) -> reshape(x)
2015 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
2016 getInput1().getDefiningOp())) {
2017 getInput1Mutable().assign(reshapeOp.getInput1());
2018 return getResult();
2019 }
2020
2021 // Cannot create an ElementsAttr from non-int/float/index types
2022 if (!inputTy.getElementType().isIntOrIndexOrFloat())
2023 return {};
2024
2025 // reshape(const(x)) -> const(reshape-attr(x))
2026 if (auto operand =
2027 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2028 // Constants must have static shape.
2029 if (!outputTy.hasStaticShape())
2030 return {};
2031
2032 // Okay to duplicate splat constants.
2033 if (operand.isSplat())
2034 return SplatElementsAttr::get(outputTy,
2035 operand.getSplatValue<Attribute>());
2036
2037 // Don't duplicate other constants.
2038 if (!getInput1().hasOneUse())
2039 return {};
2040
2042 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
2043 return {};
2044
2045 return operand.reshape(
2046 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
2047 }
2048
2049 return {};
2050}
2051
2052OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
2053 // If the pad is all zeros we can fold this operation away.
2054 if (adaptor.getPadding() && getInput1().getType() == getType()) {
2055 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
2056 if (densePad && densePad.isSplat() &&
2057 densePad.getSplatValue<APInt>().isZero()) {
2058 return getInput1();
2059 }
2060 }
2061
2062 return {};
2063}
2064
2065// Fold away cases where a tosa.resize operation returns a copy
2066// of the input image.
2067OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
2068 auto scaleAttr =
2069 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
2070 auto offsetAttr =
2071 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
2072 auto borderAttr =
2073 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
2074 if (!scaleAttr || !offsetAttr || !borderAttr) {
2075 return {};
2076 }
2077
2078 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
2079 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
2080 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
2081 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
2082 return {};
2083 }
2084
2085 // Check unit scaling.
2086 if (scale[0] != scale[1] || scale[2] != scale[3]) {
2087 return {};
2088 }
2089
2090 // There should be no offset.
2091 if (offset[0] != 0 || offset[1] != 0) {
2092 return {};
2093 }
2094
2095 // There should be no border.
2096 if (border[0] != 0 || border[1] != 0) {
2097 return {};
2098 }
2099
2100 return foldToInputIfTypeMatches(getType(), getInput());
2101}
2102
2103OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
2104 auto operand = getInput1();
2105 auto operandTy = llvm::cast<ShapedType>(operand.getType());
2106 auto axis = getAxis();
2107 // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
2108 const bool isSplatInput =
2109 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
2110 if (!operandTy.hasRank() ||
2111 (!isSplatInput && operandTy.getDimSize(axis) != 1))
2112 return {};
2113 return foldToInputIfTypeMatches(getType(), operand);
2114}
2115
2116OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
2117 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2118 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
2119
2120 if (!inputTy || !outputTy)
2121 return {};
2122
2123 if (inputTy == outputTy && inputTy.hasStaticShape())
2124 return getInput1();
2125
2126 // Check if this is a no-op slice (starts at 0 and size matches input)
2127
2128 DenseElementsAttr startElems;
2129 if (!matchPattern(getStart(), m_Constant(&startElems)))
2130 return {};
2131
2132 // Check if all start values are zero
2133 bool startIsZeros =
2134 llvm::all_of(startElems.getValues<APInt>(),
2135 [](const APInt &val) { return val.isZero(); });
2136
2137 if (startIsZeros) {
2138
2139 // Check if size matches input shape
2140 DenseElementsAttr sizeElems;
2141 if (!matchPattern(getSize(), m_Constant(&sizeElems)))
2142 return {};
2143
2144 auto inputShape = inputTy.getShape();
2145 auto sizeValues = sizeElems.getValues<APInt>();
2146
2147 bool sizeMatchesInput = true;
2148 for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
2149 int64_t size = sizeVal.getSExtValue();
2150
2151 if (inputTy.isDynamicDim(i)) {
2152 // For dynamic dimensions, check for kInferableDimSize indicating full
2153 // dimension is sliced
2154 if (size != kInferableDimSize) {
2155 sizeMatchesInput = false;
2156 break;
2157 }
2158 } else {
2159 // For static dimensions, check that size must match exactly or be
2160 // kInferableDimSize indicating full dimension is sliced
2161 if (size != kInferableDimSize && size != inputShape[i]) {
2162 sizeMatchesInput = false;
2163 break;
2164 }
2165 }
2166 }
2167
2168 if (sizeMatchesInput)
2169 return getInput1();
2170 }
2171
2172 // The following checks require the input to be a constant
2173 if (!adaptor.getInput1())
2174 return {};
2175
2176 // Cannot create an ElementsAttr from non-int/float/index types
2177 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
2178 !outputTy.getElementType().isIntOrIndexOrFloat())
2179 return {};
2180
2181 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
2182 if (operand.isSplat() && outputTy.hasStaticShape()) {
2183 return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
2184 }
2185
2186 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
2187 outputTy.getNumElements() == 1) {
2188 llvm::SmallVector<uint64_t> indices =
2189 llvm::to_vector(startElems.getValues<uint64_t>());
2190 if (auto values = operand.tryGetValues<Attribute>())
2191 return SplatElementsAttr::get(outputTy, (*values)[indices]);
2192 }
2193
2194 return {};
2195}
2196
2197OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
2198 const Value pred = getPred();
2199 const Value onTrue = getOnTrue();
2200 const Value onFalse = getOnFalse();
2201
2202 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
2203 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
2204 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
2205 if (!predTy || !onTrueTy || !onFalseTy)
2206 return {};
2207
2208 const Type resultTy = getType();
2209
2210 const ArrayRef<int64_t> predShape = predTy.getShape();
2211 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2212
2213 if (onTrue == onFalse && onTrueTy == resultTy &&
2214 OpTrait::util::staticallyKnownBroadcastable(predShape, onTrueShape))
2215 return onTrue;
2216
2217 auto predicate =
2218 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2219 if (!predicate)
2220 return {};
2221 if (!predicate.isSplat())
2222 return {};
2223
2224 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2225
2226 SmallVector<SmallVector<int64_t>, 3> shapes;
2227 shapes.emplace_back(predShape);
2228 shapes.emplace_back(onTrueShape);
2229 shapes.emplace_back(onFalseTy.getShape());
2230 const bool isBroadcastable =
2232
2233 if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
2234 return onTrue;
2235 if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)
2236 return onFalse;
2237 return {};
2238}
2239
2240OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2241 if (getInput1().getType() == getType()) {
2242 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2243 adaptor.getMultiples())) {
2244 if (multiples.isSplat() &&
2245 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2246 return getInput1();
2247 if (auto int_array_attr =
2248 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2249 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2250 [](APInt v) { return v.getSExtValue() == 1; }))
2251 return getInput1();
2252 }
2253 }
2254 }
2255 return {};
2256}
2257
2258OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2259 auto resultTy = llvm::cast<ShapedType>(getType());
2260
2261 // Transposing splat values just means reshaping.
2262 if (auto input =
2263 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2264 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2265 input.getType().getElementType() == resultTy.getElementType())
2266 return input.reshape(resultTy);
2267 }
2268
2269 // Transpose is not the identity transpose.
2270 const llvm::ArrayRef<int32_t> perms = getPerms();
2271
2272 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2273 return {};
2274
2275 return foldToInputIfTypeMatches(getType(), getInput1());
2276}
2277
2278OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2279 // Element-wise negate(negate(x)) = x
2280 // iff all zero points are constant 0
2281 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2282 if (!definingOp) {
2283 // defining op of input1 is not a negate, cannot fold
2284 return {};
2285 }
2286
2287 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2288 failed(maybeIZp) || *maybeIZp != 0) {
2289 // input1 zero point is not constant 0, cannot fold
2290 return {};
2291 }
2292 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2293 failed(maybeOZp) || *maybeOZp != 0) {
2294 // output zero point is not constant 0, cannot fold
2295 return {};
2296 }
2297 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2298 failed(maybeIZp) || *maybeIZp != 0) {
2299 // definingOp's input1 zero point is not constant 0, cannot fold
2300 return {};
2301 }
2302 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2303 failed(maybeOZp) || *maybeOZp != 0) {
2304 // definingOp's output zero point is not constant 0, cannot fold
2305 return {};
2306 }
2307
2308 return foldToInputIfTypeMatches(getType(), definingOp.getInput1());
2309}
2310
2311OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2312 auto input = getInput1();
2313 // Element-wise abs(abs(x)) = abs(x)
2314 if (input.getDefiningOp<tosa::AbsOp>())
2315 return foldToInputIfTypeMatches(getType(), input);
2316
2317 return {};
2318}
2319
2320OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2321 auto input = adaptor.getInput1();
2322
2323 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2324 // Fold splat inputs only.
2325 if (!inputAttr || !inputAttr.isSplat())
2326 return {};
2327
2328 auto shapeType = llvm::cast<ShapedType>(getType());
2329 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2330 return {};
2331 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2332 auto floatVal = inputAttr.getSplatValue<APFloat>();
2333 return DenseElementsAttr::get(shapeType,
2334 ReciprocalOp::calcOneElement(floatVal));
2335 }
2336
2337 return {};
2338}
2339
2340template <typename Op, typename OpFoldAdaptor>
2342 auto input1ConstShape =
2343 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2344 if (!input1ConstShape)
2345 return {};
2346
2347 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2348
2349 return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
2350 /*foldDenseValues=*/true);
2351}
2352
2353template <typename Op, typename OpFoldAdaptor>
2355 auto input1ConstShape =
2356 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2357 auto input2ConstShape =
2358 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2359 if (!input1ConstShape || !input2ConstShape)
2360 return {};
2361
2362 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2363 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2364
2365 return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
2366 input1Attr.getType(),
2367 /*foldDenseValues=*/true);
2368}
2369
2370OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2371 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
2372 if (!inputTy || !inputTy.hasRank())
2373 return {};
2374 const int32_t axis = getAxis();
2375 const int64_t dimSize = inputTy.getDimSize(axis);
2376 if (ShapedType::isDynamic(dimSize))
2377 return {};
2378
2379 OpBuilder builder(getContext());
2380 const auto resultAttrTy =
2381 RankedTensorType::get(/*rank=*/1, builder.getIndexType());
2382 return DenseElementsAttr::get(resultAttrTy, dimSize);
2383}
2384
2385OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
2386 auto const inputs = op->getInput();
2387
2388 if (inputs.empty())
2389 return {};
2390
2391 SmallVector<APInt> concatDims;
2392 concatDims.reserve(/*max elem*/ 64);
2393 for (auto const &v : inputs) {
2394 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2395 if (!vConstShape)
2396 return {};
2397
2398 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2399 assert(vAttr);
2400
2401 auto const vAttrVals = vAttr.getValues<APInt>();
2402 for (auto const &v : vAttrVals) {
2403 concatDims.push_back(v);
2404 }
2405 }
2406
2407 auto *ctx = op->getContext();
2408 assert(ctx != nullptr && "ctx is nullptr");
2409 auto const rankedTy = RankedTensorType::get(
2410 {static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2411
2412 return DenseElementsAttr::get(rankedTy, concatDims);
2413}
2414
2415OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op) {
2416 auto const input1 = op->getInput();
2417 auto const input2 = op->getStart();
2418 auto const input3 = op->getSize();
2419
2420 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2421
2422 if (!input1ConstShape)
2423 return {};
2424
2425 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2426 if (!input1Attr)
2427 return {};
2428
2429 auto const input1Vals = input1Attr.getValues<APInt>();
2430 auto const totalInput1 = input1Vals.size();
2431
2432 auto const start = getSingleI64From1ElementTensor(input2);
2433 auto const size = getSingleI64From1ElementTensor(input3);
2434
2435 if (failed(start) || failed(size))
2436 return {};
2437
2438 auto const startV = static_cast<int32_t>(start.value());
2439 auto const sizeV = static_cast<int32_t>(size.value());
2440
2441 if ((sizeV <= 0) || (startV < 0) ||
2442 (static_cast<size_t>(startV + sizeV) > totalInput1))
2443 return {};
2444
2445 SmallVector<APInt> sliceOfInput;
2446 sliceOfInput.reserve(totalInput1);
2447
2448 for (auto i = startV; i < (startV + sizeV); i++) {
2449 sliceOfInput.push_back(input1Vals[i]);
2450 }
2451
2452 auto *ctx = op->getContext();
2453 assert(ctx != nullptr && "ctx is nullptr");
2454
2455 auto const rankedTy = RankedTensorType::get(
2456 {static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2457
2458 return DenseElementsAttr::get(rankedTy, sliceOfInput);
2459}
2460
// Constant folder hook for tosa::AddShapeOp.
// NOTE(review): the statement between the braces (listing line 2462) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2461OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2463}
2464
// Constant folder hook for tosa::SubShapeOp.
// NOTE(review): the statement between the braces (listing line 2466) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2465OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2467}
2468
// Constant folder hook for tosa::MulShapeOp.
// NOTE(review): the statement between the braces (listing line 2470) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2469OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2471}
2472
2473OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2474 return binaryFold<DivCeilShapeOp, ShapeDivFoldAdaptor</*Ceil*/ true>>(this);
2475}
2476
2477OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2478 return binaryFold<DivFloorShapeOp, ShapeDivFoldAdaptor</*Ceil*/ false>>(this);
2479}
2480
// Constant folder hook for tosa::ModShapeOp.
// NOTE(review): the statement between the braces (listing line 2482) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2481OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2483}
2484
// Constant folder hook for tosa::MaxShapeOp.
// NOTE(review): the statement between the braces (listing line 2486) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2485OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2487}
2488
// Constant folder hook for tosa::MinShapeOp.
// NOTE(review): the statement between the braces (listing line 2490) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2489OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2491}
2492
// Constant folder hook for tosa::Exp2ShapeOp.
// NOTE(review): the statement between the braces (listing line 2494) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2493OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2495}
2496
// Constant folder hook for tosa::Log2CeilShapeOp.
// NOTE(review): the statement between the braces (listing line 2498) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2497OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2499}
2500
// Constant folder hook for tosa::Log2FloorShapeOp.
// NOTE(review): the statement between the braces (listing line 2502) is
// absent from this copy of the file, so the helper it delegates to cannot be
// confirmed here. Restore it from the upstream source before building.
2501OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2503}
2504
2505OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
2506 return concatShapeFold(this);
2507}
2508
2509OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
2510 return sliceShapeFold(this);
2511}
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:259
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
TosaLevel getTosaLevelFromEnum(const Level level)
Definition TargetEnv.cpp:15
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:106
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
TargetEnvAttr lookupTargetEnv(Operation *op)
FailureOr< T > getConstantScalarIntValue(Value val)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::RowGatherOp op, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
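Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.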
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
int32_t MAX_TENSOR_LIST_SIZE
Definition TargetEnv.h:30