MLIR 23.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//
13
19#include "mlir/Dialect/Traits.h"
22#include "mlir/IR/Matchers.h"
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28
29#include <functional>
30
31using namespace mlir;
32using namespace mlir::tosa;
33
34namespace {
35OpFoldResult foldToInputIfTypeMatches(Type typeRef, Value input) {
36 return input.getType() == typeRef ? OpFoldResult(input) : OpFoldResult{};
37}
38} // namespace
39
//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Tensor Data Engine Operators.
//===----------------------------------------------------------------------===//
47
48// Check that the zero point of the tensor and padding operations are aligned.
49static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
50 // Check that padConst is a constant value and a scalar tensor
51 DenseElementsAttr padConstAttr;
52 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
53 (padConstAttr.size() != 1)) {
54 return false;
55 }
56
57 // Check that floating point pad is zero
58 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
59 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
60 return padConstVal == 0.0f;
61 }
62
63 // Check that the zp and padConst align for the integer (quantized) case
64 if (auto padConstIntAttr =
65 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
67 // Check that zp is a constant value and a scalar tensor
68 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
69 return false;
70 }
71
72 // Check equality
73 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
74 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
75 return zpVal == padConstVal;
76 }
77
78 // Bail-out on unsupported type
79 return false;
80}
81
82namespace {
83template <typename OpTy>
84struct PoolPadFoldAdaptor;
85
86template <>
87struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
88 using OpTy = tosa::MaxPool2dOp;
89 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
90 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
91 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
92 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
93 return false;
94 return true;
95 }
96 static bool checkPadConstCompliance(OpTy, Value padConst) {
97 // Check that padConst is a constant value and a scalar tensor
98 DenseElementsAttr padConstAttr;
99 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
100 padConstAttr.size() != 1) {
101 return false;
102 }
103
104 // Pad needs to be in the minimum value to be able to merge
105 if (auto padConstFpAttr =
106 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
107 const APFloat padConstVal = *padConstFpAttr.begin();
108 const APFloat lowestVal =
109 APFloat::getLargest(padConstVal.getSemantics(), true);
110 return padConstVal == lowestVal;
111 }
112 if (auto padConstIntAttr =
113 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
114 const APInt padConstVal = *padConstIntAttr.begin();
115 const unsigned int bitWidth = padConstVal.getBitWidth();
116 const APInt lowestVal =
117 padConstIntAttr.getElementType().isUnsignedInteger()
118 ? APInt::getZero(bitWidth)
119 : APInt::getSignedMinValue(bitWidth);
120 return padConstVal == lowestVal;
121 }
122
123 // Bail-out on unsupported type
124 return false;
125 }
126 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
127 Value padInput, ArrayRef<int64_t> newPad) {
128 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
129 op, op.getType(), padInput, op.getKernel(), op.getStride(),
130 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
131 }
132};
133
134template <typename OpTy>
135struct ConvPadFoldAdaptor {
136 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
137 return true;
138 }
139 static bool checkPadConstCompliance(OpTy op, Value padConst) {
140 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
141 }
142 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
143 Value padInput, ArrayRef<int64_t> newPad) {
144 rewriter.replaceOpWithNewOp<OpTy>(
145 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
146 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
147 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
148 }
149};
150
151// Pattern attempts to fold a `tosa.pad` operator to a following tensor
152// operation like `tosa.conv2d` by merging the padding associated with the
153// pad operator directly to the implicit padding of the tensor operation.
154// This helps eliminate the explicit padding operator if unused.
155template <typename OpTy, typename AdaptorTy>
156struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
157 using OpRewritePattern<OpTy>::OpRewritePattern;
158
159 LogicalResult matchAndRewrite(OpTy tensorOp,
160 PatternRewriter &rewriter) const override {
161 // Check producer is a tosa::PadOp
162 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
163 if (!padOp)
164 return rewriter.notifyMatchFailure(tensorOp,
165 "Producer must be a tosa::PadOp.");
166
167 // Validate that tensor operation has sane padding
168 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
169 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
170 return rewriter.notifyMatchFailure(
171 tensorOp, "Tensor operation padding shall have 4 elements.");
172
173 // Validate tosa::PadOp padding
174 DenseIntElementsAttr padOpPadding;
175 if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
176 return rewriter.notifyMatchFailure(
177 tensorOp,
178 "The `padding` input specified on the tosa::PadOp must be constant.");
179 }
180 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
181 // C_after
182 if (padOpPadding.size() != 8)
183 return rewriter.notifyMatchFailure(tensorOp,
184 "Pad padding should have 8 elements.");
185 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
186 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
187 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
188 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
189 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
190 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
191 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
192 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
193
194 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
195 return rewriter.notifyMatchFailure(
196 tensorOp, "Folding padding in N or C dimensions is not supported.");
197
198 // Fold padding from Pad into the tensor operation
199 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
200 SmallVector<int64_t> foldedPad(tensorOpPad.size());
201 foldedPad[0] = padHBefore + tensorOpPad[0];
202 foldedPad[1] = padHAfter + tensorOpPad[1];
203 foldedPad[2] = padWBefore + tensorOpPad[2];
204 foldedPad[3] = padWAfter + tensorOpPad[3];
205
206 // Check kernel related restrictions
207 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
208 return rewriter.notifyMatchFailure(
209 tensorOp, "Padding size not aligned with kernel restrictions.");
210 }
211
212 // Check padding constant restrictions
213 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
214 return rewriter.notifyMatchFailure(
215 tensorOp,
216 "Padding constant is not aligned with operator zero-point.");
217 }
218
219 // Check that padding doesn't grow more than 8K level (8192) for now
220 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
221 return rewriter.notifyMatchFailure(
222 tensorOp, "Padding size more than the 8K level limit.");
223 }
224
225 // Create operator
226 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
227 foldedPad);
228
229 return success();
230 }
231};
232} // namespace
233
234void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
235 MLIRContext *context) {
236 results.add<
237 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
238 context);
239}
240
241void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
242 MLIRContext *context) {
243 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
244 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
245 context);
246}
247
249 : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
251
252 LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
253 PatternRewriter &rewriter) const override {
257 if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
258 !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
259 !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
260 return rewriter.notifyMatchFailure(
261 op, "expected constant kernel, stride, and pad operands");
262
263 auto replacement = tosa::AvgPool2dOp::create(
264 rewriter, op.getLoc(), op.getType(), op.getInput(), op.getInputZp(),
265 op.getOutputZp(), rewriter.getDenseI64ArrayAttr(kernel),
266 rewriter.getDenseI64ArrayAttr(stride),
267 rewriter.getDenseI64ArrayAttr(pad), op.getAccTypeAttr());
268 rewriter.replaceOp(op, replacement.getOutput());
269 return success();
270 }
271};
272
273void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
274 RewritePatternSet &results, MLIRContext *context) {
275 results.add<AvgPool2dAdaptiveToAvgPool2d>(context);
276}
277
278struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
280
281 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
282 PatternRewriter &rewriter) const override {
283 Value input = op.getInput();
284 Value output = op.getOutput();
285 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
286 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
287
288 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
289 return failure();
290 }
291
292 // If the output and input shapes are 1x1, then this is a no op.
293 ArrayRef<int64_t> outputShape = outputType.getShape();
294 if (outputShape[1] != 1 || outputShape[2] != 1) {
295 return failure();
296 }
297
298 ArrayRef<int64_t> inputShape = inputType.getShape();
299 if (inputShape[1] != 1 || inputShape[2] != 1) {
300 return failure();
301 }
302
303 rewriter.replaceOp(op, input);
304 return success();
305 }
306};
307
308void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
309 MLIRContext *context) {
310 results.add<MaxPool2dIsNoOp,
311 FoldPadToTensorOp<tosa::MaxPool2dOp,
312 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
313 context);
314}
315
317 : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
319
320 LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
321 PatternRewriter &rewriter) const override {
325 if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
326 !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
327 !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
328 return rewriter.notifyMatchFailure(
329 op, "expected constant kernel, stride, and pad operands");
330
331 auto replacement = tosa::MaxPool2dOp::create(
332 rewriter, op.getLoc(), op.getType(), op.getInput(),
333 rewriter.getDenseI64ArrayAttr(kernel),
334 rewriter.getDenseI64ArrayAttr(stride),
335 rewriter.getDenseI64ArrayAttr(pad), op.getNanModeAttr());
336 rewriter.replaceOp(op, replacement.getOutput());
337 return success();
338 }
339};
340
341void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
342 RewritePatternSet &results, MLIRContext *context) {
343 results.add<MaxPool2dAdaptiveToMaxPool2d>(context);
344}
345
//===----------------------------------------------------------------------===//
// Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//
349
350struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
351 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
352
353 LogicalResult matchAndRewrite(tosa::ConcatOp op,
354 PatternRewriter &rewriter) const override {
355 if (op.getInput1().size() != 1)
356 return failure();
357 if (op.getInput1().front().getType() != op.getType()) {
358 rewriter
359 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
360 op.getInput1().front())
361 .getResult();
362 return success();
363 }
364
365 rewriter.replaceOp(op, op.getInput1().front());
366 return success();
367 }
368};
369
370void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
371 MLIRContext *context) {
372 results.add<ConcatOptimization>(context);
373}
374
375LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
376 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
377 if (!notOp)
378 return failure();
379 rewriter.modifyOpInPlace(op, [&]() {
380 op.getOperation()->setOperands(
381 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
382 });
383 return success();
384}
385
387 : public OpRewritePattern<tosa::TransposeOp> {
389
390 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
391 PatternRewriter &rewriter) const override {
392 // Input is also TransposeOp - transpose(transpose(A)).
393 auto innerTranspose =
394 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
395 if (!innerTranspose)
396 return rewriter.notifyMatchFailure(transposeOp,
397 "input must be transpose operation");
398
399 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
400 const llvm::ArrayRef<int32_t> innerTransposePerms =
401 innerTranspose.getPerms();
402
403 if (transposePerms.size() != innerTransposePerms.size())
404 return rewriter.notifyMatchFailure(
405 transposeOp,
406 "transpose and inner transpose perms sizes must be equal");
407 if (transposePerms.empty())
408 return rewriter.notifyMatchFailure(
409 transposeOp, "transpose perms sizes must be positive");
410
411 // Consolidate transposes into one transpose.
412 SmallVector<int32_t> perms(transposePerms.size());
413 for (int i = 0, s = transposePerms.size(); i < s; ++i)
414 perms[i] = innerTransposePerms[transposePerms[i]];
415
416 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
417 transposeOp, transposeOp.getResult().getType(),
418 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
419
420 return success();
421 }
422};
423
424// Determines the case when tosa.transpose is a tosa.reshape operation.
425struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
427
428 LogicalResult matchAndRewrite(tosa::TransposeOp op,
429 PatternRewriter &rewriter) const override {
430 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
431 return rewriter.notifyMatchFailure(
432 op, "Src is from transpose, can compose transposes");
433
434 Value result = op.getResult();
435 for (Operation *subop : result.getUsers()) {
436 if (isa_and_nonnull<tosa::TransposeOp>(subop))
437 return rewriter.notifyMatchFailure(
438 op, "Dest is used by transpose, can compose transposes");
439 }
440
441 auto input = op.getInput1();
442 auto inputTy = llvm::cast<ShapedType>(input.getType());
443 if (!inputTy.hasRank())
444 return rewriter.notifyMatchFailure(op, "Unranked input.");
445
446 int64_t numDynDims = 0;
447 for (int i = 0; i < inputTy.getRank(); ++i)
448 if (inputTy.isDynamicDim(i))
449 numDynDims++;
450
451 if (numDynDims > 1)
452 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
453
454 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
455
456 SmallVector<int64_t> nonZeroPerms;
457 nonZeroPerms.reserve(permValues.size());
458 for (auto idx : permValues) {
459 auto sz = inputTy.getDimSize(idx);
460 if (sz != 1)
461 nonZeroPerms.push_back(idx);
462 }
463
464 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
465 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
466 return rewriter.notifyMatchFailure(op,
467 "Transpose changes memory layout.");
468
469 SmallVector<int64_t> newShape;
470 newShape.reserve(inputTy.getRank());
471 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
472 newShape.push_back(inputTy.getDimSize(permValues[i]));
473
474 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
475 op, op.getType(), op.getInput1(),
476 getTosaConstShape(rewriter, op.getLoc(), newShape));
477 return success();
478 }
479};
480
481void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
482 MLIRContext *context) {
483 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
484}
485
486struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
488
489 LogicalResult matchAndRewrite(tosa::ClampOp op,
490 PatternRewriter &rewriter) const override {
491 Value input = op.getInput();
492 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
493 auto inputElementType = inputType.getElementType();
494
495 if (isa<FloatType>(inputElementType)) {
496 // Unlike integer types, floating point types can represent infinity.
497 const auto minClamp =
498 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
499 const auto maxClamp =
500 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
501 const bool isMin = minClamp.isNegInfinity();
502 const bool isMax = maxClamp.isInfinity();
503
504 if (isMin && isMax) {
505 rewriter.replaceOp(op, input);
506 return success();
507 }
508 return failure();
509 }
510
511 // i1 types are boolean in TOSA
512 const bool isBoolean = inputElementType.isInteger(1);
513 if (inputElementType.isUnsignedInteger() || isBoolean) {
514 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
515 .getValue()
516 .getZExtValue();
517 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
518 .getValue()
519 .getZExtValue();
520
521 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
522 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
523 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
524
525 if (minClamp <= intMin && maxClamp >= intMax) {
526 rewriter.replaceOp(op, input);
527 return success();
528 }
529 return failure();
530 }
531
532 if (llvm::isa<IntegerType>(inputElementType)) {
533 const int64_t minClamp =
534 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
535 const int64_t maxClamp =
536 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
537
538 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
539 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
540 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
541
542 if (minClamp <= intMin && maxClamp >= intMax) {
543 rewriter.replaceOp(op, input);
544 return success();
545 }
546 return failure();
547 }
548
549 return failure();
550 }
551};
552
553// Attempts the following transformation:
554//
555// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
556// tensor X the following identity holds:
557//
558// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
559//
560// subject to the following valid NaN propagation semantics:
561// --------------------------------------------
562// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
563// |-------------|--------------|-------------|
564// | PROPAGATE | PROPAGATE | PROPAGATE |
565// | PROPAGATE | IGNORE | IGNORE |
566// | IGNORE | PROPAGATE | INVALID |
567// | IGNORE | IGNORE | IGNORE |
568// |------------------------------------------|
569
570struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
571 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
572
573 // Helper structure to describe the range of a clamp operation.
574 template <typename T>
575 struct ClampRange {
576 ClampRange(const T &start, const T &end) : start(start), end(end) {}
579
580 // Helper function to determine if two Clamp ranges intersect.
581 bool intersects(const ClampRange<T> &otherRange) {
582 return start < otherRange.end && otherRange.start < end;
583 }
584 };
585
586 LogicalResult matchAndRewrite(tosa::ClampOp op,
587 PatternRewriter &rewriter) const override {
588 Value input = op.getInput();
589
590 // Check the input to the CLAMP op is itself a CLAMP.
591 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
592 if (!clampOp)
593 return failure();
594
595 // Check we have a valid NaN propagation combination.
596 const auto opNanMode = op.getNanMode();
597 const auto clampNanMode = clampOp.getNanMode();
598 if (opNanMode == NanPropagationMode::IGNORE &&
599 clampNanMode == NanPropagationMode::PROPAGATE)
600 return failure();
601
602 auto maxValAttr = op.getMaxValAttr();
603 auto minValAttr = op.getMinValAttr();
604 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
605 auto clampOpMinValAttr = clampOp.getMinValAttr();
606
607 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
608 if (auto quantType =
609 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
610 inputEType = getStorageElementTypeFromQuantized(quantType);
611 }
612
613 Attribute newMinValAttr, newMaxValAttr;
614 if (mlir::isa<FloatType>(inputEType)) {
615 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
616 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
617 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
618 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
619
620 // Check we have intersecting ranges.
621 const auto opMinFloat = floatMinValAttr.getValue();
622 const auto opMaxFloat = floatMaxValAttr.getValue();
623 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
624 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
625 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
626 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
627 clampOpMaxFloat);
628 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
629 return failure();
630
631 // Run the transformation.
632 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
633 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
634 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
635 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
636 } else {
637 assert(mlir::isa<IntegerType>(inputEType));
638 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
639 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
640 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
641 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
642
643 if (inputEType.isUnsignedInteger()) {
644 // Check we have intersecting ranges.
645 const auto opMinInt = intMinValAttr.getUInt();
646 const auto opMaxInt = intMaxValAttr.getUInt();
647 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
648 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
649 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
650 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
651 clampOpMaxInt);
652 if (!opRangeIntRange.intersects(clampRangeIntRange))
653 return failure();
654
655 // Run the transformation.
656 auto newMinVal = std::max(opMinInt, clampOpMinInt);
657 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
658 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
659 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
660 } else {
661 // Check we have intersecting ranges.
662 const auto opMinInt = intMinValAttr.getInt();
663 const auto opMaxInt = intMaxValAttr.getInt();
664 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
665 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
666 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
667 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
668 clampOpMaxInt);
669 if (!opRangeIntRange.intersects(clampRangeIntRange))
670 return failure();
671
672 // Run the transformation.
673 auto newMinVal = std::max(opMinInt, clampOpMinInt);
674 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
675 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
676 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
677 }
678 }
679
680 auto newMode = (opNanMode != clampNanMode)
681 ? tosa::NanPropagationMode::IGNORE
682 : opNanMode;
683
684 auto newModeAttr =
685 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
686
687 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
688 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
689 newModeAttr);
690 return success();
691 }
692};
693
694void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
695 MLIRContext *context) {
696 results.add<ClampIsNoOp>(context);
697 results.add<ClampClampOptimization>(context);
698}
699
700struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
701 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
702
703 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
704 PatternRewriter &rewriter) const override {
705 Value sliceInput = sliceOp.getInput1();
706 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
707 if (!concatOp)
708 return rewriter.notifyMatchFailure(
709 sliceOp, "slice input must be concat operation");
710
711 OperandRange inputs = concatOp.getInput1();
712 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
713 if (!concatType || !concatType.hasStaticShape())
714 return rewriter.notifyMatchFailure(
715 sliceOp, "slice input must be a static ranked tensor");
716 int32_t axis = concatOp.getAxis();
717
718 DenseElementsAttr startElems;
719 DenseElementsAttr sizeElems;
720
721 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
722 return rewriter.notifyMatchFailure(
723 sliceOp, "start of slice must be a static ranked shape");
724
725 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
726 return rewriter.notifyMatchFailure(
727 sliceOp, "size of slice must be a static ranked shape");
728
729 llvm::SmallVector<int64_t> sliceStarts =
730 llvm::to_vector(startElems.getValues<int64_t>());
731 llvm::SmallVector<int64_t> sliceSizes =
732 llvm::to_vector(sizeElems.getValues<int64_t>());
733
734 // Validate slice on the concatenated axis. Slicing along this
735 // axis should span only one of the inputs to the concatenate
736 // operation.
737 std::optional<Value> replaceWithSlice;
738 for (auto input : inputs) {
739 auto inputType = dyn_cast<RankedTensorType>(input.getType());
740 if (!inputType || !inputType.hasStaticShape())
741 return rewriter.notifyMatchFailure(
742 sliceOp, "concat input must be a static ranked tensor");
743
744 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
745 inputType.getDimSize(axis)) {
746 auto start_op =
747 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
748 auto size_op =
749 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
750 replaceWithSlice =
751 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
752 input, start_op, size_op)
753 .getResult();
754 break;
755 }
756 sliceStarts[axis] -= inputType.getDimSize(axis);
757 }
758
759 if (!replaceWithSlice)
760 return rewriter.notifyMatchFailure(
761 sliceOp, "corresponding concat input not found for slice");
762
763 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
764 return success();
765 }
766};
767
768struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
769 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
770
771 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
772 PatternRewriter &rewriter) const override {
773 Value sliceInput = sliceOp.getInput1();
774
775 // Check if producer is a PadOp
776 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
777 if (!padOp)
778 return rewriter.notifyMatchFailure(sliceOp,
779 "slice input must be a pad operation");
780
781 // Check PadOp has a single consumer
782 if (!padOp->hasOneUse())
783 return rewriter.notifyMatchFailure(sliceOp,
784 "pad shall have a single consumer");
785
786 // Check input is statically ranked
787 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
788 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
789 if (!inputTy || !padTy || !inputTy.hasRank())
790 return rewriter.notifyMatchFailure(sliceOp,
791 "slice input must be a ranked tensor");
792
793 // Validate and extract tosa::PadOp padding
794 DenseIntElementsAttr paddingElems;
795 if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
796 return rewriter.notifyMatchFailure(
797 sliceOp,
798 "`padding` input specified on the tosa::PadOp must be constant.");
799 }
800 llvm::SmallVector<int64_t> padPaddings =
801 llvm::to_vector(paddingElems.getValues<int64_t>());
802
803 // Extract slice parameters
804 DenseElementsAttr startElems;
805 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
806 return rewriter.notifyMatchFailure(
807 sliceOp, "start of slice must be a static ranked shape");
808 llvm::SmallVector<int64_t> sliceStarts =
809 llvm::to_vector(startElems.getValues<int64_t>());
810
811 DenseElementsAttr sizeElems;
812 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
813 return rewriter.notifyMatchFailure(
814 sliceOp, "size of slice must be a static ranked shape");
815 llvm::SmallVector<int64_t> sliceSizes =
816 llvm::to_vector(sizeElems.getValues<int64_t>());
817
818 // Check if dynamic dimensions are sliced
819 const int64_t rank = inputTy.getRank();
820 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
821 const bool isDimDynamic = inputTy.isDynamicDim(i);
822 const bool isDimSliced =
823 (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);
824
825 return isDimDynamic && isDimSliced;
826 })) {
827 return rewriter.notifyMatchFailure(
828 sliceOp, "axis that are sliced shall be statically known.");
829 }
830
831 // Update the parameters
832 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
833 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
834 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
835 bool updated = false;
836
837 for (int64_t i = 0; i < rank; ++i) {
838 const int64_t padLo = padPaddings[i * 2];
839 const int64_t padHi = padPaddings[i * 2 + 1];
840 const int64_t sliceStart = sliceStarts[i];
841 const int64_t sliceSize = sliceSizes[i];
842 const int64_t sliceEnd = sliceStart + sliceSize;
843
844 // If dimension is dynamic pass-through
845 if (inputTy.isDynamicDim(i)) {
846 newPadPaddings[i * 2] = padLo;
847 newPadPaddings[i * 2 + 1] = padHi;
848 newSliceStarts[i] = sliceStart;
849 continue;
850 }
851
852 // Handle static dimensions
853 const int64_t dimSize = inputTy.getShape()[i];
854 const int64_t dimTotal = padLo + dimSize + padHi;
855
856 // Check slice within bounds
857 if (sliceStart < 0 || sliceEnd > dimTotal)
858 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
859
860 // Compute updated slice start parameter
861 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
862 newSliceStarts[i] = newSliceStart;
863 updated |= newSliceStart != sliceStart;
864
865 // Compute updated pad parameters
866 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
867 const int64_t newPadHi =
868 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
869 newPadPaddings[i * 2] = newPadLo;
870 newPadPaddings[i * 2 + 1] = newPadHi;
871 updated |= (newPadLo != padLo) || (newPadHi != padHi);
872
873 // Calculate new pad output shape
874 newPadShape[i] =
875 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
876 }
877
878 // Check that we actually need to proceed with the rewrite
879 if (!updated)
880 return rewriter.notifyMatchFailure(
881 sliceOp, "terminate condition; nothing to rewrite");
882
883 // Create a PadOp with updated padding
884 auto newPaddingsOp =
885 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
886 auto newPadTy =
887 RankedTensorType::get(newPadShape, inputTy.getElementType());
888 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
889 padOp.getInput1(), newPaddingsOp,
890 padOp.getPadConst());
891
892 // Update SliceOp and point to new PadOp
893 auto newStartOp =
894 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
895 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
896 newPadOp.getResult(), newStartOp,
897 sliceOp.getSize());
898
899 return success();
900 }
901};
902
903// Update size operand of tosa.slice if size has dynamic dims but corresponding
904// output dim is static
906 : public OpRewritePattern<tosa::SliceOp> {
907 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
908
909 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
910 PatternRewriter &rewriter) const override {
911 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
912 if (!resultType.hasRank())
913 return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");
914
915 ElementsAttr sizeElems;
916 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
917 return rewriter.notifyMatchFailure(
918 sliceOp, "size of slice must be a static ranked shape");
919 }
920
921 llvm::SmallVector<int64_t> sliceSizes =
922 llvm::to_vector(sizeElems.getValues<int64_t>());
923
924 bool replaceSliceSize{false};
925 // if size op has kInferableDimSize indicating dynamic shape but
926 // corresponding dim on the output is statically known, update size to match
927 // with known output dim shape
928 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
929 if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
930 sliceSizes[index] = resultType.getDimSize(index);
931 replaceSliceSize = true;
932 }
933 }
934
935 if (!replaceSliceSize) {
936 return rewriter.notifyMatchFailure(
937 sliceOp, "no dimension of size of slice is dynamic that resolves "
938 "to static output shape");
939 }
940
941 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
942 auto newSliceOp =
943 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
944 sliceOp.getInput1(), sliceOp.getStart(), size_op);
945
946 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
947 return success();
948 }
949};
950
951void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
952 MLIRContext *context) {
953 results.add<ConcatSliceOptimization, PadSliceOptimization,
954 SliceDynamicSizeCanonicalization>(context);
955}
956
957struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
958 using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
959
960 LogicalResult matchAndRewrite(tosa::CastOp castOp,
961 PatternRewriter &rewriter) const override {
962 const Value castInput = castOp.getInput();
963 auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
964 if (!innerCastOp)
965 return rewriter.notifyMatchFailure(castOp,
966 "input must be cast operation");
967
968 const Value innerCastInput = innerCastOp.getInput();
969
970 const ShapedType innerInputType =
971 llvm::cast<ShapedType>(innerCastInput.getType());
972 const ShapedType innerOutputType =
973 llvm::cast<ShapedType>(innerCastOp.getType());
974 const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
975
976 const Type innerInputElemType = innerInputType.getElementType();
977 const Type innerOutputElemType = innerOutputType.getElementType();
978 const Type outerOutputElemType = outerOutputType.getElementType();
979
980 const SmallVector<Type, 3> types = {innerInputElemType, innerOutputElemType,
981 outerOutputElemType};
982
983 if (llvm::any_of(types, [](const Type type) {
984 // Support a specific set of floating point types since we need to be
985 // careful in not introducing unsupported type combinations
986 return !(type.isInteger() ||
987 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
988 Float16Type, Float32Type>(type));
989 }))
990 return rewriter.notifyMatchFailure(
991 castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
992 "supported");
993
994 if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
995 llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
996 return rewriter.notifyMatchFailure(
997 castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
998 "legal in TOSA");
999 }
1000
1001 if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
1002 llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
1003 return rewriter.notifyMatchFailure(
1004 castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
1005 "legal in TOSA");
1006 }
1007
1008 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
1009 outerOutputElemType.isInteger()) {
1010 return rewriter.notifyMatchFailure(
1011 castOp, "avoid introducing fp8 -> integer casts which are not "
1012 "legal in TOSA");
1013 }
1014
1015 if (innerInputElemType.isInteger() &&
1016 llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
1017 return rewriter.notifyMatchFailure(
1018 castOp, "avoid introducing integer -> fp8 casts which are not "
1019 "legal in TOSA");
1020 }
1021
1022 if (llvm::isa<Float16Type>(innerInputElemType) &&
1023 llvm::isa<BFloat16Type>(outerOutputElemType)) {
1024 return rewriter.notifyMatchFailure(
1025 castOp, "avoid introducing fp16 -> bf16 casts which are not "
1026 "legal in TOSA");
1027 }
1028
1029 if (llvm::isa<BFloat16Type>(innerInputElemType) &&
1030 llvm::isa<Float16Type>(outerOutputElemType)) {
1031 return rewriter.notifyMatchFailure(
1032 castOp, "avoid introducing bf16 -> fp16 casts which are not "
1033 "legal in TOSA");
1034 }
1035
1036 const auto isIntegerOneOfWidth = [](Type type, size_t bitwidth1,
1037 size_t bitwidth2) {
1038 return type.isInteger(bitwidth1) || type.isInteger(bitwidth2);
1039 };
1040
1041 if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
1042 outerOutputElemType.isInteger(64)) {
1043 return rewriter.notifyMatchFailure(
1044 castOp, "avoid introducing i8/i16 -> i64 casts which are not "
1045 "legal in TOSA");
1046 }
1047
1048 if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
1049 !outerOutputElemType.isInteger()) {
1050 return rewriter.notifyMatchFailure(
1051 castOp, "avoid introducing bool/i64 to float casts which are not "
1052 "supported in all versions of TOSA");
1053 }
1054
1055 if (!innerInputElemType.isInteger() &&
1056 isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
1057 return rewriter.notifyMatchFailure(
1058 castOp, "avoid introducing float to bool/i64 casts which are not "
1059 "supported in all versions of TOSA");
1060 }
1061
1062 // Check that the cast we're considering for removal is non-narrowing
1063 if (isNarrowingCast(innerInputType, innerOutputType))
1064 return rewriter.notifyMatchFailure(castOp,
1065 "inner cast operation is narrowing");
1066
1067 rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
1068 innerCastInput);
1069
1070 return success();
1071 }
1072
1073 bool supportsNaN(const llvm::fltSemantics &semantics) const {
1074 return semantics.nonFiniteBehavior !=
1075 llvm::fltNonfiniteBehavior::FiniteOnly;
1076 }
1077
1078 bool supportsInf(const llvm::fltSemantics &semantics) const {
1079 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
1080 }
1081
1082 bool isNarrowingCast(const ShapedType inType,
1083 const ShapedType outType) const {
1084
1085 if (inType.getElementType().isInteger() &&
1086 outType.getElementType().isInteger()) {
1087
1088 const auto inTypeSignedness =
1089 cast<IntegerType>(inType.getElementType()).getSignedness();
1090 const auto outTypeSignedness =
1091 cast<IntegerType>(outType.getElementType()).getSignedness();
1092
1093 return (inTypeSignedness != outTypeSignedness ||
1094 inType.getElementTypeBitWidth() >
1095 outType.getElementTypeBitWidth());
1096 }
1097
1098 if (inType.getElementType().isFloat() &&
1099 outType.getElementType().isFloat()) {
1100
1101 FloatType inElemTy = cast<FloatType>(inType.getElementType());
1102 FloatType outElemTy = cast<FloatType>(outType.getElementType());
1103 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
1104 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
1105
1106 // If the list of supported types needs to be updated in the future, the
1107 // check down below will need to be revised, for example to account for
1108 // unsigned floating point types, or types that use negative zero as the
1109 // representation for NaN.
1110 [[maybe_unused]] const auto isSupported = [](Type elemType) {
1111 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1112 Float16Type, Float32Type>(elemType);
1113 };
1114
1115 assert(isSupported(inElemTy) &&
1116 "unsupported input element type in isNarrowingCast");
1117 assert(isSupported(outElemTy) &&
1118 "unsupported output element type in isNarrowingCast");
1119
1120 return (
1121 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1122 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1123 inTypeSemantics.precision > outTypeSemantics.precision ||
1124 (supportsNaN(inTypeSemantics) && !supportsNaN(outTypeSemantics)) ||
1125 (supportsInf(inTypeSemantics) && !supportsInf(outTypeSemantics)));
1126 }
1127
1128 // While some cases of int -> float casts can be non-narrowing, consider
1129 // them narrowing for the purposes of this optimization
1130 return true;
1131 }
1132};
1133
// Register canonicalization patterns for tosa.cast.
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<NonNarrowingCastsOptimization>(context);
}
1138
1140 : public OpRewritePattern<tosa::CastToBlockScaledOp> {
1141 using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
1142
1143 LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
1144 PatternRewriter &rewriter) const override {
1145 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1146 auto castFromBlockScaledOp =
1147 castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
1148 if (!castFromBlockScaledOp)
1149 return rewriter.notifyMatchFailure(
1150 castToBlockScaledOp,
1151 "input must be cast_from_block_scaled operation");
1152
1153 const Value innerData = castFromBlockScaledOp.getInputData();
1154 const Value innerScale = castFromBlockScaledOp.getInputScale();
1155 const auto innerDataTy = llvm::cast<ShapedType>(innerData.getType());
1156 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.getType());
1157
1158 const Value outerData = castToBlockScaledOp.getOutputData();
1159 const Value outerScale = castToBlockScaledOp.getOutputScale();
1160 const auto outerDataTy = llvm::cast<ShapedType>(outerData.getType());
1161 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.getType());
1162
1163 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1164 return rewriter.notifyMatchFailure(
1165 castToBlockScaledOp,
1166 "inputs types to cast_from_block_scaled operation must match output "
1167 "types to cast_to_block_scaled");
1168 }
1169
1170 if (castFromBlockScaledOp.getBlockSize() !=
1171 castToBlockScaledOp.getBlockSize()) {
1172 return rewriter.notifyMatchFailure(
1173 castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
1174 "cast_to_block_scaled must match");
1175 }
1176
1177 rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});
1178
1179 return success();
1180 }
1181};
1182
// Register canonicalization patterns for tosa.cast_to_block_scaled.
void CastToBlockScaledOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<CancellingBlockScaledCastsOptimization>(context);
}
1187
1188//===----------------------------------------------------------------------===//
1189// Operator Folders.
1190//===----------------------------------------------------------------------===//
1191
1192template <typename Folder>
1193static DenseElementsAttr
1195 bool foldDenseValues = false) {
1196 if (!lhs || !rhs)
1197 return {};
1198
1199 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1200 return {};
1201
1202 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
1203 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
1204 if (lETy != rETy)
1205 return {};
1206
1207 if (lhs.isSplat() && rhs.isSplat()) {
1208 if (isa<FloatType>(lETy)) {
1209 const APFloat l = lhs.getSplatValue<APFloat>();
1210 const APFloat r = rhs.getSplatValue<APFloat>();
1211 const auto maybeResult = Folder::fold(l, r);
1212 if (failed(maybeResult))
1213 return {};
1214 return DenseElementsAttr::get(returnTy, maybeResult.value());
1215 }
1216
1217 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1218 const APInt l = lhs.getSplatValue<APInt>();
1219 const APInt r = rhs.getSplatValue<APInt>();
1220 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1221 if (failed(maybeResult))
1222 return {};
1223 return DenseElementsAttr::get(returnTy, maybeResult.value());
1224 }
1225 }
1226
1227 if (foldDenseValues) {
1228 assert(lETy.isIntOrIndex() &&
1229 "Only integer types are currently supported.");
1230 SmallVector<APInt> resultValues;
1231 for (auto [l, r] :
1232 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
1233 const auto maybeResult = Folder::fold(l, r, false);
1234 if (failed(maybeResult))
1235 return {};
1236 resultValues.push_back(maybeResult.value());
1237 }
1238 return DenseElementsAttr::get(returnTy, resultValues);
1239 }
1240
1241 return {};
1242}
1243
1244template <typename Folder>
1245static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
1246 bool foldDenseValues = false) {
1247 if (!val)
1248 return {};
1249
1250 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1251 return {};
1252
1253 const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
1254
1255 if (val.isSplat()) {
1256 if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1257 const APInt v = val.getSplatValue<APInt>();
1258 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1259 if (failed(maybeResult))
1260 return {};
1261 return DenseElementsAttr::get(returnTy, maybeResult.value());
1262 }
1263 }
1264
1265 if (foldDenseValues) {
1266 mlir::Type elemTy = val.getElementType();
1267 if (elemTy.isIntOrIndex()) {
1268 SmallVector<APInt> resultValues;
1269 for (auto const &v : val.getValues<APInt>()) {
1270 const auto maybeResult = Folder::fold(v, false);
1271 if (failed(maybeResult))
1272 return {};
1273 resultValues.push_back(maybeResult.value());
1274 }
1275 return DenseElementsAttr::get(returnTy, resultValues);
1276 }
1277 }
1278
1279 // Folding arbitrarily sized tensor operations is not supported
1280 return {};
1281}
1282
1283static FailureOr<int64_t> getSingleI64From1ElementTensor(Value v) {
1284 DenseIntElementsAttr dense{};
1285 if (!matchPattern(v, m_Constant(&dense)))
1286 return failure();
1287
1288 assert(dense.isSplat());
1289 APInt a = dense.getSplatValue<APInt>();
1290 return a.getSExtValue();
1291}
1292
  // Overflow-checked integer addition; failure blocks folding when the sum
  // wraps in the operand width.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    bool overflow;
    const APInt result =
        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
    if (overflow)
      return failure();
    return result;
  }
1303
  // Floating-point addition using APFloat's default semantics.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs + rhs;
  }
1307};
1308
  // Overflow-checked integer subtraction; failure blocks folding when the
  // difference wraps in the operand width.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    bool overflow;
    const APInt result =
        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
    if (overflow)
      return failure();
    return result;
  }
1319
  // Floating-point subtraction using APFloat's default semantics.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs - rhs;
  }
1323};
1324
  // Overflow-checked integer multiplication; fails on width mismatch or when
  // the product wraps.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {

    const unsigned originalWidth = lhs.getBitWidth();

    // Check same type
    if (lhs.getBitWidth() != rhs.getBitWidth()) {
      return failure();
    }

    // If either is `0`
    if (lhs == 0 || rhs == 0)
      return APInt::getZero(originalWidth);

    bool overflow = false;
    APInt const result =
        isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);

    if (overflow)
      return failure();

    // The *_ov multiply keeps the operand width, so this trunc is effectively
    // a no-op kept for clarity.
    return result.trunc(originalWidth);
  }
1349
  // Floating-point multiplication using APFloat's default semantics.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs * rhs;
  }
1353};
1354
1355static bool signsDiffer(const APInt &a, const APInt &b) {
1356 return a.isNegative() != b.isNegative();
1357}
1358
1359template <bool Ceil>
  // Integer division with trunc-toward-zero (Ceil == false) or
  // round-toward-positive-infinity (Ceil == true) rounding. Fails on width
  // mismatch, division by zero, or signed overflow (INT_MIN / -1).
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    if (rhs.isZero())
      return failure();

    if (isUnsigned) {
      APInt q{};
      APInt r{};
      APInt::udivrem(lhs, rhs, q, r);
      // Unsigned quotients are never negative, so a non-zero remainder
      // always means the ceiling is one above the truncated quotient.
      if (!r.isZero() && Ceil) {
        return q + 1;
      }
      return q;
    }

    // Signed: start from trunc-toward-zero, then adjust to ceil.
    bool overflow{false};
    APInt const q = lhs.sdiv_ov(rhs, overflow);
    if (overflow)
      return failure();
    APInt const r = lhs.srem(rhs);

    if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) {
      // Same sign => exact quotient is positive; trunc is below ceil =>
      // increment q.
      return q + 1;
    }
    return q;
  }
1392
  // Floating-point division; no ceil adjustment applies to the float
  // overload.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs / rhs;
  }
1396};
1397
  // Integer remainder, conservatively restricted to a non-negative dividend
  // and strictly positive divisor. Note isNegative()/isStrictlyPositive()
  // inspect the sign bit even for unsigned operands, so large unsigned
  // values are rejected as well.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    if (lhs.isNegative() || (!rhs.isStrictlyPositive()))
      return failure();

    if (isUnsigned) {
      return lhs.urem(rhs);
    }

    return lhs.srem(rhs);
  }
1412
  // Floating-point remainder. APFloat::mod mutates in place, so operate on a
  // copy, and only fold when the operation reports opOK.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    auto t = lhs;
    auto const r = t.mod(rhs);
    if (llvm::APFloatBase::opStatus::opOK == r) {
      return t;
    }
    return failure();
  }
1421};
1422
1424 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1425 bool isUnsigned) {
1426 if (lhs.getBitWidth() != rhs.getBitWidth())
1427 return failure();
1428 return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;
1429 }
1430
  // Returns the larger operand. When either operand is NaN the comparison is
  // false and rhs is returned — NOTE(review): confirm this matches the
  // intended TOSA NaN-propagation mode.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs >= rhs ? lhs : rhs;
  }
1434};
1435
1437 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1438 bool isUnsigned) {
1439 if (lhs.getBitWidth() != rhs.getBitWidth())
1440 return failure();
1441 return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;
1442 }
1443
  // Returns the smaller operand. When either operand is NaN the comparison
  // is false and rhs is returned — NOTE(review): confirm this matches the
  // intended TOSA NaN-propagation mode.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs <= rhs ? lhs : rhs;
  }
1447};
1448
  // Computes 1 << value as a same-width APInt with only bit `value` set.
  // Fails when the shift amount is negative or >= the bit width, i.e. when
  // the result is not representable.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    auto const numBits = value.getBitWidth();
    if (isUnsigned) {
      auto const zextv = value.getZExtValue();
      if (zextv >= numBits)
        return failure();
      return APInt::getOneBitSet(numBits, zextv);
    }
    auto const sextv = value.getSExtValue();
    if (sextv < 0 || sextv >= numBits || (value.isNegative()))
      return failure();
    return APInt::getOneBitSet(numBits, sextv);
  }
1463};
1464
  // Computes ceil(log2(value)); only strictly positive inputs fold.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
  }
1471};
1472
  // Computes floor(log2(value)); only strictly positive inputs fold.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
  }
1479};
1480
  // Signedness-aware greater-than producing a 1-bit (boolean) result.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
  }
1486
  // Floating-point greater-than; unordered comparisons (any NaN operand)
  // yield false.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs > rhs);
  }
1490};
1491
  // Signedness-aware greater-or-equal producing a 1-bit (boolean) result.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
  }
1497
  // Floating-point greater-or-equal; unordered comparisons (any NaN operand)
  // yield false.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs >= rhs);
  }
1501};
1502
  // Bitwise equality; identical for signed and unsigned interpretations, so
  // isUnsigned is intentionally unused.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return APInt(1, lhs == rhs);
  }
1508
  // Floating-point equality; NaN compares unequal to everything, including
  // itself.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs == rhs);
  }
1512};
1513
1514static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1515 if (llvm::isa<FloatType>(elemType))
1516 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1517 if (llvm::isa<IntegerType>(elemType))
1518 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1519 return false;
1520}
1521
1522static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1523 if (llvm::isa<FloatType>(elemType))
1524 return val && val.isSplat() &&
1525 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1526 if (llvm::isa<IntegerType>(elemType)) {
1527 const int64_t shifted = 1LL << shift;
1528 return val && val.isSplat() &&
1529 val.getSplatValue<APInt>().getSExtValue() == shifted;
1530 }
1531 return false;
1532}
1533
// Fold tosa.add: x + 0 -> x when the surviving operand already has the
// result type, and constant-splat + constant-splat via AddFoldAdaptor.
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x + 0 -> x: only valid when the kept operand's type equals the result
  // type (no implicit broadcast would be dropped).
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1564
1565OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1566 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1567 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1568 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1569 !outputTy.hasStaticShape())
1570 return {};
1571
1572 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1573 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1574 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1575 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1576 return DenseElementsAttr::get(outputTy, zero);
1577 }
1578
1579 return {};
1580}
1581
// Fold tosa.int_div: 0 / x -> 0, x / 1 -> x, and splat-constant division
// with trunc-toward-zero rounding when the divisor is non-zero.
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy.getElementType() != rhsTy.getElementType())
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized
  // type
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  // 0 / x -> 0; needs a static result type to materialize the splat.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr.resizeSplat(resultTy);
  }

  // x / 1 -> x; only valid when x already has the result type.
  if (rhsAttr && rhsAttr.isSplat()) {
    const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
        lhsTy.getShape(), rhsTy.getShape());
    if (isBroadcastable && lhsTy == resultTy &&
        llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // Constant splat / constant splat, guarding against division by zero.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
      auto const result =
          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
      if (failed(result))
        return {};
      return DenseElementsAttr::get(resultTy, result.value());
    }
  }

  return {};
}
1629
1630namespace {
1631// calculate lhs * rhs >> shift according to TOSA Spec
1632// return nullopt if result is not in range of int32_t when shift > 0
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  // Widen to 64 bits so the raw product cannot wrap before the range check.
  bool overflow = false;
  APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);

  if (overflow)
    return std::nullopt;

  if (shift > 0) {
    // Round to nearest by adding half of the shifted-out range before the
    // arithmetic shift.
    auto round = APInt(64, 1) << (shift - 1);
    result += round;
    result.ashrInPlace(shift);
    // REQUIRE(product >= minimum_s<i32_t>() && product <=
    // maximum_s<i32_t>())
    if (!(result.getSExtValue() >= INT32_MIN &&
          result.getSExtValue() <= INT32_MAX)) {
      // REQUIRE failed
      return std::nullopt;
    }
  }

  return result.trunc(bitwidth);
}
1656
// Fold the product of two constant splats, honoring tosa.mul's i32 result
// shift. Returns a null attribute when folding is not possible.
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      // Zero shift: plain (wrapping) multiply.
      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      // Non-zero shift: apply the TOSA rounding/range-checked multiply.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
      if (!result)
        return {};
      return DenseElementsAttr::get(ty, result.value());
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
1685} // namespace
1686
// Fold tosa.mul: x * 0 -> 0, x * 1 -> x (shift-aware for i32), and
// splat-constant multiplication via mulBinaryFolder.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Result right shift on i32_t data type only. For simplification,
  // synthesize a zero shift for other data type.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  // x * 0 -> splat zero.
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
      resultTy.hasStaticShape())
    // constant values can only be resized if resulting type is static
    return lhsAttr.resizeSplat(resultTy);
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
      resultTy.hasStaticShape())
    return rhsAttr.resizeSplat(resultTy);

  // x * 1 -> x, where integer "one" accounts for the shift (1 << shift).
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && rhsTy == resultTy &&
      isSplatOne(resultETy, lhsAttr, shift))
    return rhs;
  if (isBroadcastable && lhsTy == resultTy &&
      isSplatOne(resultETy, rhsAttr, shift))
    return lhs;

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
1734
// Fold tosa.sub: x - 0 -> x (note 0 - x is not an identity, so there is no
// lhs variant), and constant-splat - constant-splat via SubFoldAdaptor.
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x - 0 -> x; only valid when x already has the result type.
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1763
1764OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1765 auto resultTy = llvm::cast<ShapedType>(getType());
1766 auto lhsAttr =
1767 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1768 auto rhsAttr =
1769 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1770
1771 if (!lhsAttr || !rhsAttr)
1772 return {};
1773
1774 return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1775}
1776
1777OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1778 auto resultTy = llvm::cast<ShapedType>(getType());
1779 auto lhsAttr =
1780 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1781 auto rhsAttr =
1782 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1783
1784 if (!lhsAttr || !rhsAttr)
1785 return {};
1786
1787 return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1788}
1789
// Fold tosa.equal: x == x folds to all-true for integer tensors, and
// constant operands fold element-wise via EqualFoldAdaptor.
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // If we are comparing an integer value to itself it is always true. We
  // can not do this with floats, since NaN != NaN makes x == x false for
  // NaN elements.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1812
// Folds tosa.cast in two situations:
//   * the cast is an identity (input type == result type), or
//   * the input is a constant splat, in which case the single value is
//     converted between element types (float<->float, int->float,
//     float->int, int<->int) and rebuilt as a splat of the result type.
// Non-splat constants are not folded here.
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  // Identity cast: just forward the input value.
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  // Building a constant result attribute requires a ranked, static shape.
  if (!outTy.hasRank() || !outTy.hasStaticShape())
    return {};
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: re-round the value into the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float: the source's signedness drives the conversion.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int: round to nearest-even into an APSInt of the result width.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate or extend to the result width.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      const auto inIntType = llvm::cast<IntegerType>(inETy);
      auto unsignIn = inIntType.isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      // i1 types are boolean in TOSA
      if (outETy.isInteger(1)) {
        intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
      } else if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn || inIntType.isInteger(1)) {
        // Unsigned (and boolean) sources zero-extend; signed sources
        // sign-extend in the branch below.
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
1882
1883OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1884
1885OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1886
// Folder shared by all TOSA reductions: a reduction is a no-op when the
// input type equals the result type AND the reduced axis has extent 1 (or
// the tensor is rank 0), in which case it returns its input unchanged.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
1906
1907OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1908 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1909 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1910
1911 if (!inputTy || !outputTy)
1912 return {};
1913
1914 // Fold when the input and output types are the same. This is only safe
1915 // when there is at most 1 dynamic dimension. For 2 or more dynamic
1916 // dimensions, there may still be a productive reshape.
1917 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1918 return getInput1();
1919
1920 // reshape(reshape(x)) -> reshape(x)
1921 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1922 getInput1().getDefiningOp())) {
1923 getInput1Mutable().assign(reshapeOp.getInput1());
1924 return getResult();
1925 }
1926
1927 // Cannot create an ElementsAttr from non-int/float/index types
1928 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1929 return {};
1930
1931 // reshape(const(x)) -> const(reshape-attr(x))
1932 if (auto operand =
1933 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1934 // Constants must have static shape.
1935 if (!outputTy.hasStaticShape())
1936 return {};
1937
1938 // Okay to duplicate splat constants.
1939 if (operand.isSplat())
1940 return SplatElementsAttr::get(outputTy,
1941 operand.getSplatValue<Attribute>());
1942
1943 // Don't duplicate other constants.
1944 if (!getInput1().hasOneUse())
1945 return {};
1946
1948 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1949 return {};
1950
1951 return operand.reshape(
1952 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1953 }
1954
1955 return {};
1956}
1957
1958OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1959 // If the pad is all zeros we can fold this operation away.
1960 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1961 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1962 if (densePad && densePad.isSplat() &&
1963 densePad.getSplatValue<APInt>().isZero()) {
1964 return getInput1();
1965 }
1966 }
1967
1968 return {};
1969}
1970
1971// Fold away cases where a tosa.resize operation returns a copy
1972// of the input image.
1973OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1974 auto scaleAttr =
1975 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1976 auto offsetAttr =
1977 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1978 auto borderAttr =
1979 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1980 if (!scaleAttr || !offsetAttr || !borderAttr) {
1981 return {};
1982 }
1983
1984 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1985 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1986 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1987 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1988 return {};
1989 }
1990
1991 // Check unit scaling.
1992 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1993 return {};
1994 }
1995
1996 // There should be no offset.
1997 if (offset[0] != 0 || offset[1] != 0) {
1998 return {};
1999 }
2000
2001 // There should be no border.
2002 if (border[0] != 0 || border[1] != 0) {
2003 return {};
2004 }
2005
2006 return foldToInputIfTypeMatches(getType(), getInput());
2007}
2008
2009OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
2010 auto operand = getInput1();
2011 auto operandTy = llvm::cast<ShapedType>(operand.getType());
2012 auto axis = getAxis();
2013 // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
2014 const bool isSplatInput =
2015 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
2016 if (!operandTy.hasRank() ||
2017 (!isSplatInput && operandTy.getDimSize(axis) != 1))
2018 return {};
2019 return foldToInputIfTypeMatches(getType(), operand);
2020}
2021
// Folds tosa.slice in three cases:
//   * identity slice: start is all zeros and the sizes cover the whole
//     input (or are kInferableDimSize), so the op returns its input;
//   * slice of a splat constant: rebuilt as a splat of the result type;
//   * single-element slice of a static constant: extracted directly.
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Same static type on both sides: nothing can be sliced away.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  // Check if this is a no-op slice (starts at 0 and size matches input)

  DenseElementsAttr startElems;
  if (!matchPattern(getStart(), m_Constant(&startElems)))
    return {};

  // Check if all start values are zero
  bool startIsZeros =
      llvm::all_of(startElems.getValues<APInt>(),
                   [](const APInt &val) { return val.isZero(); });

  if (startIsZeros) {

    // Check if size matches input shape
    DenseElementsAttr sizeElems;
    if (!matchPattern(getSize(), m_Constant(&sizeElems)))
      return {};

    auto inputShape = inputTy.getShape();
    auto sizeValues = sizeElems.getValues<APInt>();

    bool sizeMatchesInput = true;
    for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
      int64_t size = sizeVal.getSExtValue();

      if (inputTy.isDynamicDim(i)) {
        // For dynamic dimensions, check for kInferableDimSize indicating full
        // dimension is sliced
        if (size != kInferableDimSize) {
          sizeMatchesInput = false;
          break;
        }
      } else {
        // For static dimensions, check that size must match exactly or be
        // kInferableDimSize indicating full dimension is sliced
        if (size != kInferableDimSize && size != inputShape[i]) {
          sizeMatchesInput = false;
          break;
        }
      }
    }

    if (sizeMatchesInput)
      return getInput1();
  }

  // The following checks require the input to be a constant
  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // Any slice of a splat is the same splat with the result shape.
  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // A one-element result can be read straight out of the constant using the
  // (all-constant) start indices.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    if (auto values = operand.tryGetValues<Attribute>())
      return SplatElementsAttr::get(outputTy, (*values)[indices]);
  }

  return {};
}
2102
2103OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
2104 const Value pred = getPred();
2105 const Value onTrue = getOnTrue();
2106 const Value onFalse = getOnFalse();
2107
2108 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
2109 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
2110 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
2111 if (!predTy || !onTrueTy || !onFalseTy)
2112 return {};
2113
2114 const Type resultTy = getType();
2115
2116 const ArrayRef<int64_t> predShape = predTy.getShape();
2117 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2118
2119 if (onTrue == onFalse && onTrueTy == resultTy &&
2120 OpTrait::util::staticallyKnownBroadcastable(predShape, onTrueShape))
2121 return onTrue;
2122
2123 auto predicate =
2124 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2125 if (!predicate)
2126 return {};
2127 if (!predicate.isSplat())
2128 return {};
2129
2130 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2131
2132 SmallVector<SmallVector<int64_t>, 3> shapes;
2133 shapes.emplace_back(predShape);
2134 shapes.emplace_back(onTrueShape);
2135 shapes.emplace_back(onFalseTy.getShape());
2136 const bool isBroadcastable =
2138
2139 if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
2140 return onTrue;
2141 if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)
2142 return onFalse;
2143 return {};
2144}
2145
2146OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2147 if (getInput1().getType() == getType()) {
2148 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2149 adaptor.getMultiples())) {
2150 if (multiples.isSplat() &&
2151 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2152 return getInput1();
2153 if (auto int_array_attr =
2154 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2155 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2156 [](APInt v) { return v.getSExtValue() == 1; }))
2157 return getInput1();
2158 }
2159 }
2160 }
2161 return {};
2162}
2163
2164OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2165 auto resultTy = llvm::cast<ShapedType>(getType());
2166
2167 // Transposing splat values just means reshaping.
2168 if (auto input =
2169 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2170 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2171 input.getType().getElementType() == resultTy.getElementType())
2172 return input.reshape(resultTy);
2173 }
2174
2175 // Transpose is not the identity transpose.
2176 const llvm::ArrayRef<int32_t> perms = getPerms();
2177
2178 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2179 return {};
2180
2181 return foldToInputIfTypeMatches(getType(), getInput1());
2182}
2183
2184OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2185 // Element-wise negate(negate(x)) = x
2186 // iff all zero points are constant 0
2187 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2188 if (!definingOp) {
2189 // defining op of input1 is not a negate, cannot fold
2190 return {};
2191 }
2192
2193 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2194 failed(maybeIZp) || *maybeIZp != 0) {
2195 // input1 zero point is not constant 0, cannot fold
2196 return {};
2197 }
2198 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2199 failed(maybeOZp) || *maybeOZp != 0) {
2200 // output zero point is not constant 0, cannot fold
2201 return {};
2202 }
2203 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2204 failed(maybeIZp) || *maybeIZp != 0) {
2205 // definingOp's input1 zero point is not constant 0, cannot fold
2206 return {};
2207 }
2208 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2209 failed(maybeOZp) || *maybeOZp != 0) {
2210 // definingOp's output zero point is not constant 0, cannot fold
2211 return {};
2212 }
2213
2214 return foldToInputIfTypeMatches(getType(), definingOp.getInput1());
2215}
2216
2217OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2218 auto input = getInput1();
2219 // Element-wise abs(abs(x)) = abs(x)
2220 if (input.getDefiningOp<tosa::AbsOp>())
2221 return foldToInputIfTypeMatches(getType(), input);
2222
2223 return {};
2224}
2225
2226OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
2227 // Fold consecutive concats on the same axis into a single op.
2228 // Keep track of the operands so we are able to construct a new concat
2229 // later. Conservatively assume that we double the number of operands when
2230 // folding
2231 SmallVector<Value, 8> concatOperands;
2232 concatOperands.reserve(2 * getNumOperands());
2233
2234 // Find all operands that are foldable concats
2235 bool foundFoldableConcat = false;
2236 for (Value operand : getOperands()) {
2237 concatOperands.emplace_back(operand);
2238
2239 auto producer = operand.getDefiningOp<ConcatOp>();
2240 if (!producer)
2241 continue;
2242
2243 // Not foldable if axes are not the same
2244 if (getAxis() != producer.getAxis())
2245 continue;
2246
2247 // Replace the original operand with all incoming operands
2248 foundFoldableConcat = true;
2249 concatOperands.pop_back();
2250 llvm::append_range(concatOperands, producer->getOperands());
2251 }
2252
2253 if (!foundFoldableConcat)
2254 return {};
2255
2256 getOperation()->setOperands(concatOperands);
2257 return getResult();
2258}
2259
2260OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2261 auto input = adaptor.getInput1();
2262
2263 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2264 // Fold splat inputs only.
2265 if (!inputAttr || !inputAttr.isSplat())
2266 return {};
2267
2268 auto shapeType = llvm::cast<ShapedType>(getType());
2269 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2270 return {};
2271 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2272 auto floatVal = inputAttr.getSplatValue<APFloat>();
2273 return DenseElementsAttr::get(shapeType,
2274 ReciprocalOp::calcOneElement(floatVal));
2275 }
2276
2277 return {};
2278}
2279
2280template <typename Op, typename OpFoldAdaptor>
2282 auto input1ConstShape =
2283 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2284 if (!input1ConstShape)
2285 return {};
2286
2287 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2288
2289 return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
2290 /*foldDenseValues=*/true);
2291}
2292
2293template <typename Op, typename OpFoldAdaptor>
2295 auto input1ConstShape =
2296 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2297 auto input2ConstShape =
2298 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2299 if (!input1ConstShape || !input2ConstShape)
2300 return {};
2301
2302 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2303 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2304
2305 return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
2306 input1Attr.getType(),
2307 /*foldDenseValues=*/true);
2308}
2309
2310OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2311 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
2312 if (!inputTy || !inputTy.hasRank())
2313 return {};
2314 const int32_t axis = getAxis();
2315 const int64_t dimSize = inputTy.getDimSize(axis);
2316 if (ShapedType::isDynamic(dimSize))
2317 return {};
2318
2319 OpBuilder builder(getContext());
2320 const auto resultAttrTy =
2321 RankedTensorType::get(/*rank=*/1, builder.getIndexType());
2322 return DenseElementsAttr::get(resultAttrTy, dimSize);
2323}
2324
2325OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
2326 auto const inputs = op->getInput();
2327
2328 if (inputs.empty())
2329 return {};
2330
2331 SmallVector<APInt> concatDims;
2332 concatDims.reserve(/*max elem*/ 64);
2333 for (auto const &v : inputs) {
2334 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2335 if (!vConstShape)
2336 return {};
2337
2338 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2339 assert(vAttr);
2340
2341 auto const vAttrVals = vAttr.getValues<APInt>();
2342 for (auto const &v : vAttrVals) {
2343 concatDims.push_back(v);
2344 }
2345 }
2346
2347 auto *ctx = op->getContext();
2348 assert(ctx != nullptr && "ctx is nullptr");
2349 auto const rankedTy = RankedTensorType::get(
2350 {static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2351
2352 return DenseElementsAttr::get(rankedTy, concatDims);
2353}
2354
// Folds tosa.slice_shape when the input is a tosa.const_shape and the
// start/size operands are single-element constant tensors: the result is a
// constant index tensor of the selected dimension values. Empty,
// negative-start, or out-of-bounds slices are left unfolded.
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op) {
  auto const input1 = op->getInput();
  auto const input2 = op->getStart();
  auto const input3 = op->getSize();

  auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());

  if (!input1ConstShape)
    return {};

  auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
  if (!input1Attr)
    return {};

  auto const input1Vals = input1Attr.getValues<APInt>();
  auto const totalInput1 = input1Vals.size();

  // Start and size must each be a constant scalar (single-element tensor).
  auto const start = getSingleI64From1ElementTensor(input2);
  auto const size = getSingleI64From1ElementTensor(input3);

  if (failed(start) || failed(size))
    return {};

  auto const startV = static_cast<int32_t>(start.value());
  auto const sizeV = static_cast<int32_t>(size.value());

  // Reject empty, negative-start, or out-of-bounds slices.
  if ((sizeV <= 0) || (startV < 0) ||
      (static_cast<size_t>(startV + sizeV) > totalInput1))
    return {};

  SmallVector<APInt> sliceOfInput;
  sliceOfInput.reserve(totalInput1);

  for (auto i = startV; i < (startV + sizeV); i++) {
    sliceOfInput.push_back(input1Vals[i]);
  }

  auto *ctx = op->getContext();
  assert(ctx != nullptr && "ctx is nullptr");

  auto const rankedTy = RankedTensorType::get(
      {static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));

  return DenseElementsAttr::get(rankedTy, sliceOfInput);
}
2400
2401OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2403}
2404
2405OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2407}
2408
2409OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2411}
2412
// Folds div_ceil_shape on constant shapes via the shared binary shape
// folder, using ceiling-division semantics.
OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
  return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
}
2416
// Folds div_floor_shape on constant shapes via the shared binary shape
// folder, using floor-division semantics.
OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
  return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
}
2420
2421OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2423}
2424
2425OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2427}
2428
2429OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2431}
2432
2433OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2435}
2436
2437OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2439}
2440
2441OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2443}
2444
// Delegates to concatShapeFold, which folds when every input is a
// tosa.const_shape.
OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
  return concatShapeFold(this);
}
2448
// Delegates to sliceShapeFold, which folds when the input is a
// tosa.const_shape and start/size are constant scalars.
OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
  return sliceShapeFold(this);
}
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...