MLIR 23.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
20#include "mlir/Dialect/Traits.h"
23#include "mlir/IR/Matchers.h"
27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/APInt.h"
29
30#include <functional>
31
32using namespace mlir;
33using namespace mlir::tosa;
34
35namespace {
36OpFoldResult foldToInputIfTypeMatches(Type typeRef, Value input) {
37 return input.getType() == typeRef ? OpFoldResult(input) : OpFoldResult{};
38}
39} // namespace
40
41//===----------------------------------------------------------------------===//
42// Operator Canonicalizers.
43//===----------------------------------------------------------------------===//
44
45//===----------------------------------------------------------------------===//
46// Tensor Data Engine Operators.
47//===----------------------------------------------------------------------===//
48
49// Check that the zero point of the tensor and padding operations are aligned.
50static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
51 // Check that padConst is a constant value and a scalar tensor
52 DenseElementsAttr padConstAttr;
53 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
54 (padConstAttr.size() != 1)) {
55 return false;
56 }
57
58 // Check that floating point pad is zero
59 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
60 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
61 return padConstVal == 0.0f;
62 }
63
64 // Check that the zp and padConst align for the integer (quantized) case
65 if (auto padConstIntAttr =
66 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
68 // Check that zp is a constant value and a scalar tensor
69 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
70 return false;
71 }
72
73 // Check equality
74 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
75 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
76 return zpVal == padConstVal;
77 }
78
79 // Bail-out on unsupported type
80 return false;
81}
82
83namespace {
// Pooling-operator hooks consumed by FoldPadToTensorOp. The primary template is
// only declared; each supported pooling op provides an explicit specialization
// (e.g. tosa::MaxPool2dOp below).
template <typename OpTy>
struct PoolPadFoldAdaptor;
86
87template <>
88struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
89 using OpTy = tosa::MaxPool2dOp;
90 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
91 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
92 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
93 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
94 return false;
95 return true;
96 }
97 static bool checkPadConstCompliance(OpTy, Value padConst) {
98 // Check that padConst is a constant value and a scalar tensor
99 DenseElementsAttr padConstAttr;
100 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
101 padConstAttr.size() != 1) {
102 return false;
103 }
104
105 // Pad needs to be in the minimum value to be able to merge
106 if (auto padConstFpAttr =
107 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
108 const APFloat padConstVal = *padConstFpAttr.begin();
109 const APFloat lowestVal =
110 APFloat::getLargest(padConstVal.getSemantics(), true);
111 return padConstVal == lowestVal;
112 }
113 if (auto padConstIntAttr =
114 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
115 const APInt padConstVal = *padConstIntAttr.begin();
116 const unsigned int bitWidth = padConstVal.getBitWidth();
117 const APInt lowestVal =
118 padConstIntAttr.getElementType().isUnsignedInteger()
119 ? APInt::getZero(bitWidth)
120 : APInt::getSignedMinValue(bitWidth);
121 return padConstVal == lowestVal;
122 }
123
124 // Bail-out on unsupported type
125 return false;
126 }
127 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
128 Value padInput, ArrayRef<int64_t> newPad) {
129 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
130 op, op.getType(), padInput, op.getKernel(), op.getStride(),
131 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
132 }
133};
134
135template <typename OpTy>
136struct ConvPadFoldAdaptor {
137 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
138 return true;
139 }
140 static bool checkPadConstCompliance(OpTy op, Value padConst) {
141 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
142 }
143 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
144 Value padInput, ArrayRef<int64_t> newPad) {
145 rewriter.replaceOpWithNewOp<OpTy>(
146 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
147 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
148 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
149 }
150};
151
152// Pattern attempts to fold a `tosa.pad` operator to a following tensor
153// operation like `tosa.conv2d` by merging the padding associated with the
154// pad operator directly to the implicit padding of the tensor operation.
155// This helps eliminate the explicit padding operator if unused.
156template <typename OpTy, typename AdaptorTy>
157struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
158 using OpRewritePattern<OpTy>::OpRewritePattern;
159
160 LogicalResult matchAndRewrite(OpTy tensorOp,
161 PatternRewriter &rewriter) const override {
162 // Check producer is a tosa::PadOp
163 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
164 if (!padOp)
165 return rewriter.notifyMatchFailure(tensorOp,
166 "Producer must be a tosa::PadOp.");
167
168 // Validate that tensor operation has sane padding
169 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
170 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
171 return rewriter.notifyMatchFailure(
172 tensorOp, "Tensor operation padding shall have 4 elements.");
173
174 // Validate tosa::PadOp padding
175 DenseIntElementsAttr padOpPadding;
176 if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
177 return rewriter.notifyMatchFailure(
178 tensorOp,
179 "The `padding` input specified on the tosa::PadOp must be constant.");
180 }
181 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
182 // C_after
183 if (padOpPadding.size() != 8)
184 return rewriter.notifyMatchFailure(tensorOp,
185 "Pad padding should have 8 elements.");
186 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
187 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
188 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
189 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
190 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
191 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
192 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
193 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
194
195 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
196 return rewriter.notifyMatchFailure(
197 tensorOp, "Folding padding in N or C dimensions is not supported.");
198
199 // Fold padding from Pad into the tensor operation
200 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
201 SmallVector<int64_t> foldedPad(tensorOpPad.size());
202 foldedPad[0] = padHBefore + tensorOpPad[0];
203 foldedPad[1] = padHAfter + tensorOpPad[1];
204 foldedPad[2] = padWBefore + tensorOpPad[2];
205 foldedPad[3] = padWAfter + tensorOpPad[3];
206
207 // Check kernel related restrictions
208 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
209 return rewriter.notifyMatchFailure(
210 tensorOp, "Padding size not aligned with kernel restrictions.");
211 }
212
213 // Check padding constant restrictions
214 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
215 return rewriter.notifyMatchFailure(
216 tensorOp,
217 "Padding constant is not aligned with operator zero-point.");
218 }
219
220 // Check that padding doesn't grow more than 8K level (8192) for now
221 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
222 return rewriter.notifyMatchFailure(
223 tensorOp, "Padding size more than the 8K level limit.");
224 }
225
226 // Create operator
227 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
228 foldedPad);
229
230 return success();
231 }
232};
233} // namespace
234
235void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
236 MLIRContext *context) {
237 results.add<
238 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
239 context);
240}
241
242void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
243 MLIRContext *context) {
244 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
245 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
246 context);
247}
248
    : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
  // Rewrites a tosa::AvgPool2dAdaptiveOp into a tosa::AvgPool2dOp once its
  // kernel, stride and pad operands resolve to constant shape values; those
  // values become the dense i64 array attributes of the replacement. Input,
  // zero points, result type and accumulator type are forwarded unchanged.

  LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out unless all three shape operands yield constant values.
    if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
        !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
        !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
      return rewriter.notifyMatchFailure(
          op, "expected constant kernel, stride, and pad operands");

    auto replacement = tosa::AvgPool2dOp::create(
        rewriter, op.getLoc(), op.getType(), op.getInput(), op.getInputZp(),
        op.getOutputZp(), rewriter.getDenseI64ArrayAttr(kernel),
        rewriter.getDenseI64ArrayAttr(stride),
        rewriter.getDenseI64ArrayAttr(pad), op.getAccTypeAttr());
    rewriter.replaceOp(op, replacement.getOutput());
    return success();
  }
};
273
274void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
275 RewritePatternSet &results, MLIRContext *context) {
276 results.add<AvgPool2dAdaptiveToAvgPool2d>(context);
277}
278
279struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
281
282 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
283 PatternRewriter &rewriter) const override {
284 Value input = op.getInput();
285 Value output = op.getOutput();
286 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
287 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
288
289 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
290 return failure();
291 }
292
293 // If the output and input shapes are 1x1, then this is a no op.
294 ArrayRef<int64_t> outputShape = outputType.getShape();
295 if (outputShape[1] != 1 || outputShape[2] != 1) {
296 return failure();
297 }
298
299 ArrayRef<int64_t> inputShape = inputType.getShape();
300 if (inputShape[1] != 1 || inputShape[2] != 1) {
301 return failure();
302 }
303
304 rewriter.replaceOp(op, input);
305 return success();
306 }
307};
308
309void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
310 MLIRContext *context) {
311 results.add<MaxPool2dIsNoOp,
312 FoldPadToTensorOp<tosa::MaxPool2dOp,
313 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
314 context);
315}
316
    : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
  // Rewrites a tosa::MaxPool2dAdaptiveOp into a tosa::MaxPool2dOp once its
  // kernel, stride and pad operands resolve to constant shape values; those
  // values become the dense i64 array attributes of the replacement. Input,
  // result type and NaN mode are forwarded unchanged.

  LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out unless all three shape operands yield constant values.
    if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(), kernel) ||
        !tosa::getConstShapeValues(op.getStride().getDefiningOp(), stride) ||
        !tosa::getConstShapeValues(op.getPad().getDefiningOp(), pad))
      return rewriter.notifyMatchFailure(
          op, "expected constant kernel, stride, and pad operands");

    auto replacement = tosa::MaxPool2dOp::create(
        rewriter, op.getLoc(), op.getType(), op.getInput(),
        rewriter.getDenseI64ArrayAttr(kernel),
        rewriter.getDenseI64ArrayAttr(stride),
        rewriter.getDenseI64ArrayAttr(pad), op.getNanModeAttr());
    rewriter.replaceOp(op, replacement.getOutput());
    return success();
  }
};
341
342void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
343 RewritePatternSet &results, MLIRContext *context) {
344 results.add<MaxPool2dAdaptiveToMaxPool2d>(context);
345}
346
347//===----------------------------------------------------------------------===//
348// Data Layout / Memory Reinterpretation.
349//===----------------------------------------------------------------------===//
350
351struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
352 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
353
354 LogicalResult matchAndRewrite(tosa::ConcatOp op,
355 PatternRewriter &rewriter) const override {
356 if (op.getInput1().size() != 1)
357 return failure();
358 if (op.getInput1().front().getType() != op.getType()) {
359 rewriter
360 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
361 op.getInput1().front())
362 .getResult();
363 return success();
364 }
365
366 rewriter.replaceOp(op, op.getInput1().front());
367 return success();
368 }
369};
370
371struct ConsecutiveConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
372 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
373
374 LogicalResult matchAndRewrite(tosa::ConcatOp op,
375 PatternRewriter &rewriter) const override {
376 // Rewrite consecutive concats on the same axis into a single op.
377 // Keep track of the operands so we are able to construct a new concat
378 // later. Conservatively assume that we double the number of operands when
379 // canonicalizing
380 SmallVector<Value, 8> concatOperands;
381 concatOperands.reserve(2 * op.getNumOperands());
382
383 int32_t maxNumOperands = 0;
384 if (auto targetEnvAttr = tosa::lookupTargetEnv(op))
385 maxNumOperands =
386 getTosaLevelFromEnum(targetEnvAttr.getLevel()).MAX_TENSOR_LIST_SIZE;
387
388 // Find all operands that are foldable concats
389 bool foundRewritableConcat = false;
390 for (Value operand : op.getOperands()) {
391 concatOperands.emplace_back(operand);
392
393 auto producer = operand.getDefiningOp<tosa::ConcatOp>();
394 if (!producer)
395 continue;
396
397 // Not rewritable if axes are not the same
398 if (op.getAxis() != producer.getAxis())
399 continue;
400
401 // Replace the original operand with all incoming operands
402 foundRewritableConcat = true;
403 concatOperands.pop_back();
404 llvm::append_range(concatOperands, producer->getOperands());
405 }
406
407 if (!foundRewritableConcat)
408 return rewriter.notifyMatchFailure(op,
409 "No rewritable concat operand found.");
410
411 if (maxNumOperands > 0 &&
412 concatOperands.size() > static_cast<size_t>(maxNumOperands))
413 return rewriter.notifyMatchFailure(
414 op, "Rewriting would exceed the maximum number of operands for the "
415 "target environment level.");
416
417 rewriter.replaceOpWithNewOp<tosa::ConcatOp>(
418 op, op.getType(), concatOperands, op.getAxisAttr());
419 return success();
420 }
421};
422
// Canonicalization pattern registration for tosa::ConcatOp.
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
}
427
428LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
429 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
430 if (!notOp)
431 return failure();
432 rewriter.modifyOpInPlace(op, [&]() {
433 op.getOperation()->setOperands(
434 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
435 });
436 return success();
437}
438
    : public OpRewritePattern<tosa::TransposeOp> {
  // Composes transpose(transpose(A, inner), outer) into a single
  // transpose(A, perms) where perms[i] = inner[outer[i]].

  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Input is also TransposeOp - transpose(transpose(A)).
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
    const llvm::ArrayRef<int32_t> innerTransposePerms =
        innerTranspose.getPerms();

    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate transposes into one transpose.
    // Output dim i of the outer transpose reads dim transposePerms[i] of the
    // inner result, which in turn reads dim
    // innerTransposePerms[transposePerms[i]] of A.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));

    return success();
  }
};
476
477// Determines the case when tosa.transpose is a tosa.reshape operation.
478struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
480
481 LogicalResult matchAndRewrite(tosa::TransposeOp op,
482 PatternRewriter &rewriter) const override {
483 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
484 return rewriter.notifyMatchFailure(
485 op, "Src is from transpose, can compose transposes");
486
487 Value result = op.getResult();
488 for (Operation *subop : result.getUsers()) {
489 if (isa_and_nonnull<tosa::TransposeOp>(subop))
490 return rewriter.notifyMatchFailure(
491 op, "Dest is used by transpose, can compose transposes");
492 }
493
494 auto input = op.getInput1();
495 auto inputTy = llvm::cast<ShapedType>(input.getType());
496 if (!inputTy.hasRank())
497 return rewriter.notifyMatchFailure(op, "Unranked input.");
498
499 int64_t numDynDims = 0;
500 for (int i = 0; i < inputTy.getRank(); ++i)
501 if (inputTy.isDynamicDim(i))
502 numDynDims++;
503
504 if (numDynDims > 1)
505 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
506
507 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
508
509 SmallVector<int64_t> nonZeroPerms;
510 nonZeroPerms.reserve(permValues.size());
511 for (auto idx : permValues) {
512 auto sz = inputTy.getDimSize(idx);
513 if (sz != 1)
514 nonZeroPerms.push_back(idx);
515 }
516
517 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
518 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
519 return rewriter.notifyMatchFailure(op,
520 "Transpose changes memory layout.");
521
522 SmallVector<int64_t> newShape;
523 newShape.reserve(inputTy.getRank());
524 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
525 newShape.push_back(inputTy.getDimSize(permValues[i]));
526
527 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
528 op, op.getType(), op.getInput1(),
529 getTosaConstShape(rewriter, op.getLoc(), newShape));
530 return success();
531 }
532};
533
534void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
535 MLIRContext *context) {
536 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
537}
538
539struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
541
542 LogicalResult matchAndRewrite(tosa::ClampOp op,
543 PatternRewriter &rewriter) const override {
544 Value input = op.getInput();
545 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
546 auto inputElementType = inputType.getElementType();
547
548 if (isa<FloatType>(inputElementType)) {
549 // Unlike integer types, floating point types can represent infinity.
550 const auto minClamp =
551 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
552 const auto maxClamp =
553 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
554 const bool isMin = minClamp.isNegInfinity();
555 const bool isMax = maxClamp.isInfinity();
556
557 if (isMin && isMax) {
558 rewriter.replaceOp(op, input);
559 return success();
560 }
561 return failure();
562 }
563
564 // i1 types are boolean in TOSA
565 const bool isBoolean = inputElementType.isInteger(1);
566 if (inputElementType.isUnsignedInteger() || isBoolean) {
567 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
568 .getValue()
569 .getZExtValue();
570 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
571 .getValue()
572 .getZExtValue();
573
574 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
575 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
576 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
577
578 if (minClamp <= intMin && maxClamp >= intMax) {
579 rewriter.replaceOp(op, input);
580 return success();
581 }
582 return failure();
583 }
584
585 if (llvm::isa<IntegerType>(inputElementType)) {
586 const int64_t minClamp =
587 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
588 const int64_t maxClamp =
589 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
590
591 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
592 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
593 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
594
595 if (minClamp <= intMin && maxClamp >= intMax) {
596 rewriter.replaceOp(op, input);
597 return success();
598 }
599 return failure();
600 }
601
602 return failure();
603 }
604};
605
606// Attempts the following transformation:
607//
608// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
609// tensor X the following identity holds:
610//
611// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
612//
613// subject to the following valid NaN propagation semantics:
614// --------------------------------------------
615// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
616// |-------------|--------------|-------------|
617// | PROPAGATE | PROPAGATE | PROPAGATE |
618// | PROPAGATE | IGNORE | IGNORE |
619// | IGNORE | PROPAGATE | INVALID |
620// | IGNORE | IGNORE | IGNORE |
621// |------------------------------------------|
622
623struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
624 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
625
626 // Helper structure to describe the range of a clamp operation.
627 template <typename T>
628 struct ClampRange {
629 ClampRange(const T &start, const T &end) : start(start), end(end) {}
632
633 // Helper function to determine if two Clamp ranges intersect.
634 bool intersects(const ClampRange<T> &otherRange) {
635 return start < otherRange.end && otherRange.start < end;
636 }
637 };
638
639 LogicalResult matchAndRewrite(tosa::ClampOp op,
640 PatternRewriter &rewriter) const override {
641 Value input = op.getInput();
642
643 // Check the input to the CLAMP op is itself a CLAMP.
644 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
645 if (!clampOp)
646 return failure();
647
648 // Check we have a valid NaN propagation combination.
649 const auto opNanMode = op.getNanMode();
650 const auto clampNanMode = clampOp.getNanMode();
651 if (opNanMode == NanPropagationMode::IGNORE &&
652 clampNanMode == NanPropagationMode::PROPAGATE)
653 return failure();
654
655 auto maxValAttr = op.getMaxValAttr();
656 auto minValAttr = op.getMinValAttr();
657 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
658 auto clampOpMinValAttr = clampOp.getMinValAttr();
659
660 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
661 if (auto quantType =
662 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
663 inputEType = getStorageElementTypeFromQuantized(quantType);
664 }
665
666 Attribute newMinValAttr, newMaxValAttr;
667 if (mlir::isa<FloatType>(inputEType)) {
668 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
669 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
670 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
671 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
672
673 // Check we have intersecting ranges.
674 const auto opMinFloat = floatMinValAttr.getValue();
675 const auto opMaxFloat = floatMaxValAttr.getValue();
676 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
677 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
678 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
679 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
680 clampOpMaxFloat);
681 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
682 return failure();
683
684 // Run the transformation.
685 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
686 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
687 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
688 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
689 } else {
690 assert(mlir::isa<IntegerType>(inputEType));
691 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
692 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
693 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
694 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
695
696 if (inputEType.isUnsignedInteger()) {
697 // Check we have intersecting ranges.
698 const auto opMinInt = intMinValAttr.getUInt();
699 const auto opMaxInt = intMaxValAttr.getUInt();
700 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
701 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
702 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
703 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
704 clampOpMaxInt);
705 if (!opRangeIntRange.intersects(clampRangeIntRange))
706 return failure();
707
708 // Run the transformation.
709 auto newMinVal = std::max(opMinInt, clampOpMinInt);
710 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
711 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
712 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
713 } else {
714 // Check we have intersecting ranges.
715 const auto opMinInt = intMinValAttr.getInt();
716 const auto opMaxInt = intMaxValAttr.getInt();
717 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
718 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
719 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
720 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
721 clampOpMaxInt);
722 if (!opRangeIntRange.intersects(clampRangeIntRange))
723 return failure();
724
725 // Run the transformation.
726 auto newMinVal = std::max(opMinInt, clampOpMinInt);
727 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
728 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
729 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
730 }
731 }
732
733 auto newMode = (opNanMode != clampNanMode)
734 ? tosa::NanPropagationMode::IGNORE
735 : opNanMode;
736
737 auto newModeAttr =
738 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
739
740 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
741 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
742 newModeAttr);
743 return success();
744 }
745};
746
747void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
748 MLIRContext *context) {
749 results.add<ClampIsNoOp>(context);
750 results.add<ClampClampOptimization>(context);
751}
752
753struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
754 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
755
756 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
757 PatternRewriter &rewriter) const override {
758 Value sliceInput = sliceOp.getInput1();
759 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
760 if (!concatOp)
761 return rewriter.notifyMatchFailure(
762 sliceOp, "slice input must be concat operation");
763
764 OperandRange inputs = concatOp.getInput1();
765 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
766 if (!concatType || !concatType.hasStaticShape())
767 return rewriter.notifyMatchFailure(
768 sliceOp, "slice input must be a static ranked tensor");
769 int32_t axis = concatOp.getAxis();
770
771 DenseElementsAttr startElems;
772 DenseElementsAttr sizeElems;
773
774 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
775 return rewriter.notifyMatchFailure(
776 sliceOp, "start of slice must be a static ranked shape");
777
778 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
779 return rewriter.notifyMatchFailure(
780 sliceOp, "size of slice must be a static ranked shape");
781
782 llvm::SmallVector<int64_t> sliceStarts =
783 llvm::to_vector(startElems.getValues<int64_t>());
784 llvm::SmallVector<int64_t> sliceSizes =
785 llvm::to_vector(sizeElems.getValues<int64_t>());
786
787 // Validate slice on the concatenated axis. Slicing along this
788 // axis should span only one of the inputs to the concatenate
789 // operation.
790 std::optional<Value> replaceWithSlice;
791 for (auto input : inputs) {
792 auto inputType = dyn_cast<RankedTensorType>(input.getType());
793 if (!inputType || !inputType.hasStaticShape())
794 return rewriter.notifyMatchFailure(
795 sliceOp, "concat input must be a static ranked tensor");
796
797 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
798 inputType.getDimSize(axis)) {
799 auto start_op =
800 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
801 auto size_op =
802 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
803 replaceWithSlice =
804 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
805 input, start_op, size_op)
806 .getResult();
807 break;
808 }
809 sliceStarts[axis] -= inputType.getDimSize(axis);
810 }
811
812 if (!replaceWithSlice)
813 return rewriter.notifyMatchFailure(
814 sliceOp, "corresponding concat input not found for slice");
815
816 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
817 return success();
818 }
819};
820
821struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
822 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
823
824 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
825 PatternRewriter &rewriter) const override {
826 Value sliceInput = sliceOp.getInput1();
827
828 // Check if producer is a PadOp
829 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
830 if (!padOp)
831 return rewriter.notifyMatchFailure(sliceOp,
832 "slice input must be a pad operation");
833
834 // Check PadOp has a single consumer
835 if (!padOp->hasOneUse())
836 return rewriter.notifyMatchFailure(sliceOp,
837 "pad shall have a single consumer");
838
839 // Check input is statically ranked
840 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
841 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
842 if (!inputTy || !padTy || !inputTy.hasRank())
843 return rewriter.notifyMatchFailure(sliceOp,
844 "slice input must be a ranked tensor");
845
846 // Validate and extract tosa::PadOp padding
847 DenseIntElementsAttr paddingElems;
848 if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
849 return rewriter.notifyMatchFailure(
850 sliceOp,
851 "`padding` input specified on the tosa::PadOp must be constant.");
852 }
853 llvm::SmallVector<int64_t> padPaddings =
854 llvm::to_vector(paddingElems.getValues<int64_t>());
855
856 // Extract slice parameters
857 DenseElementsAttr startElems;
858 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
859 return rewriter.notifyMatchFailure(
860 sliceOp, "start of slice must be a static ranked shape");
861 llvm::SmallVector<int64_t> sliceStarts =
862 llvm::to_vector(startElems.getValues<int64_t>());
863
864 DenseElementsAttr sizeElems;
865 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
866 return rewriter.notifyMatchFailure(
867 sliceOp, "size of slice must be a static ranked shape");
868 llvm::SmallVector<int64_t> sliceSizes =
869 llvm::to_vector(sizeElems.getValues<int64_t>());
870
871 // Check if dynamic dimensions are sliced
872 const int64_t rank = inputTy.getRank();
873 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
874 const bool isDimDynamic = inputTy.isDynamicDim(i);
875 const bool isDimSliced =
876 (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);
877
878 return isDimDynamic && isDimSliced;
879 })) {
880 return rewriter.notifyMatchFailure(
881 sliceOp, "axis that are sliced shall be statically known.");
882 }
883
884 // Update the parameters
885 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
886 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
887 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
888 bool updated = false;
889
890 for (int64_t i = 0; i < rank; ++i) {
891 const int64_t padLo = padPaddings[i * 2];
892 const int64_t padHi = padPaddings[i * 2 + 1];
893 const int64_t sliceStart = sliceStarts[i];
894 const int64_t sliceSize = sliceSizes[i];
895 const int64_t sliceEnd = sliceStart + sliceSize;
896
897 // If dimension is dynamic pass-through
898 if (inputTy.isDynamicDim(i)) {
899 newPadPaddings[i * 2] = padLo;
900 newPadPaddings[i * 2 + 1] = padHi;
901 newSliceStarts[i] = sliceStart;
902 continue;
903 }
904
905 // Handle static dimensions
906 const int64_t dimSize = inputTy.getShape()[i];
907 const int64_t dimTotal = padLo + dimSize + padHi;
908
909 // Check slice within bounds
910 if (sliceStart < 0 || sliceEnd > dimTotal)
911 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
912
913 // Compute updated slice start parameter
914 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
915 newSliceStarts[i] = newSliceStart;
916 updated |= newSliceStart != sliceStart;
917
918 // Compute updated pad parameters
919 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
920 const int64_t newPadHi =
921 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
922 newPadPaddings[i * 2] = newPadLo;
923 newPadPaddings[i * 2 + 1] = newPadHi;
924 updated |= (newPadLo != padLo) || (newPadHi != padHi);
925
926 // Calculate new pad output shape
927 newPadShape[i] =
928 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
929 }
930
931 // Check that we actually need to proceed with the rewrite
932 if (!updated)
933 return rewriter.notifyMatchFailure(
934 sliceOp, "terminate condition; nothing to rewrite");
935
936 // Create a PadOp with updated padding
937 auto newPaddingsOp =
938 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
939 auto newPadTy =
940 RankedTensorType::get(newPadShape, inputTy.getElementType());
941 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
942 padOp.getInput1(), newPaddingsOp,
943 padOp.getPadConst());
944
945 // Update SliceOp and point to new PadOp
946 auto newStartOp =
947 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
948 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
949 newPadOp.getResult(), newStartOp,
950 sliceOp.getSize());
951
952 return success();
953 }
954};
955
956// Update size operand of tosa.slice if size has dynamic dims but corresponding
957// output dim is static
959 : public OpRewritePattern<tosa::SliceOp> {
960 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
961
962 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
963 PatternRewriter &rewriter) const override {
964 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
965 if (!resultType.hasRank())
966 return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");
967
968 ElementsAttr sizeElems;
969 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
970 return rewriter.notifyMatchFailure(
971 sliceOp, "size of slice must be a static ranked shape");
972 }
973
974 llvm::SmallVector<int64_t> sliceSizes =
975 llvm::to_vector(sizeElems.getValues<int64_t>());
976
977 bool replaceSliceSize{false};
978 // if size op has kInferableDimSize indicating dynamic shape but
979 // corresponding dim on the output is statically known, update size to match
980 // with known output dim shape
981 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
982 if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
983 sliceSizes[index] = resultType.getDimSize(index);
984 replaceSliceSize = true;
985 }
986 }
987
988 if (!replaceSliceSize) {
989 return rewriter.notifyMatchFailure(
990 sliceOp, "no dimension of size of slice is dynamic that resolves "
991 "to static output shape");
992 }
993
994 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
995 auto newSliceOp =
996 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
997 sliceOp.getInput1(), sliceOp.getStart(), size_op);
998
999 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
1000 return success();
1001 }
1002};
1003
1004void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1005 MLIRContext *context) {
1006 results.add<ConcatSliceOptimization, PadSliceOptimization,
1007 SliceDynamicSizeCanonicalization>(context);
1008}
1009
1011 using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
1012
1013 LogicalResult matchAndRewrite(tosa::CastOp castOp,
1014 PatternRewriter &rewriter) const override {
1015 const Value castInput = castOp.getInput();
1016 auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
1017 if (!innerCastOp)
1018 return rewriter.notifyMatchFailure(castOp,
1019 "input must be cast operation");
1020
1021 const Value innerCastInput = innerCastOp.getInput();
1022
1023 const ShapedType innerInputType =
1024 llvm::cast<ShapedType>(innerCastInput.getType());
1025 const ShapedType innerOutputType =
1026 llvm::cast<ShapedType>(innerCastOp.getType());
1027 const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
1028
1029 const Type innerInputElemType = innerInputType.getElementType();
1030 const Type innerOutputElemType = innerOutputType.getElementType();
1031 const Type outerOutputElemType = outerOutputType.getElementType();
1032
1033 const SmallVector<Type, 3> types = {innerInputElemType, innerOutputElemType,
1034 outerOutputElemType};
1035
1036 if (llvm::any_of(types, [](const Type type) {
1037 // Support a specific set of floating point types since we need to be
1038 // careful in not introducing unsupported type combinations
1039 return !(type.isInteger() ||
1040 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1041 Float16Type, Float32Type>(type));
1042 }))
1043 return rewriter.notifyMatchFailure(
1044 castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
1045 "supported");
1046
1047 if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
1048 llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
1049 return rewriter.notifyMatchFailure(
1050 castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
1051 "legal in TOSA");
1052 }
1053
1054 if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
1055 llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
1056 return rewriter.notifyMatchFailure(
1057 castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
1058 "legal in TOSA");
1059 }
1060
1061 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
1062 outerOutputElemType.isInteger()) {
1063 return rewriter.notifyMatchFailure(
1064 castOp, "avoid introducing fp8 -> integer casts which are not "
1065 "legal in TOSA");
1066 }
1067
1068 if (innerInputElemType.isInteger() &&
1069 llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
1070 return rewriter.notifyMatchFailure(
1071 castOp, "avoid introducing integer -> fp8 casts which are not "
1072 "legal in TOSA");
1073 }
1074
1075 if (llvm::isa<Float16Type>(innerInputElemType) &&
1076 llvm::isa<BFloat16Type>(outerOutputElemType)) {
1077 return rewriter.notifyMatchFailure(
1078 castOp, "avoid introducing fp16 -> bf16 casts which are not "
1079 "legal in TOSA");
1080 }
1081
1082 if (llvm::isa<BFloat16Type>(innerInputElemType) &&
1083 llvm::isa<Float16Type>(outerOutputElemType)) {
1084 return rewriter.notifyMatchFailure(
1085 castOp, "avoid introducing bf16 -> fp16 casts which are not "
1086 "legal in TOSA");
1087 }
1088
1089 const auto isIntegerOneOfWidth = [](Type type, size_t bitwidth1,
1090 size_t bitwidth2) {
1091 return type.isInteger(bitwidth1) || type.isInteger(bitwidth2);
1092 };
1093
1094 if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
1095 outerOutputElemType.isInteger(64)) {
1096 return rewriter.notifyMatchFailure(
1097 castOp, "avoid introducing i8/i16 -> i64 casts which are not "
1098 "legal in TOSA");
1099 }
1100
1101 if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
1102 !outerOutputElemType.isInteger()) {
1103 return rewriter.notifyMatchFailure(
1104 castOp, "avoid introducing bool/i64 to float casts which are not "
1105 "supported in all versions of TOSA");
1106 }
1107
1108 if (!innerInputElemType.isInteger() &&
1109 isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
1110 return rewriter.notifyMatchFailure(
1111 castOp, "avoid introducing float to bool/i64 casts which are not "
1112 "supported in all versions of TOSA");
1113 }
1114
1115 // Check that the cast we're considering for removal is non-narrowing
1116 if (isNarrowingCast(innerInputType, innerOutputType))
1117 return rewriter.notifyMatchFailure(castOp,
1118 "inner cast operation is narrowing");
1119
1120 rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
1121 innerCastInput);
1122
1123 return success();
1124 }
1125
1126 bool supportsNaN(const llvm::fltSemantics &semantics) const {
1127 return semantics.nonFiniteBehavior !=
1128 llvm::fltNonfiniteBehavior::FiniteOnly;
1129 }
1130
1131 bool supportsInf(const llvm::fltSemantics &semantics) const {
1132 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
1133 }
1134
1135 bool isNarrowingCast(const ShapedType inType,
1136 const ShapedType outType) const {
1137
1138 if (inType.getElementType().isInteger() &&
1139 outType.getElementType().isInteger()) {
1140
1141 const auto inTypeSignedness =
1142 cast<IntegerType>(inType.getElementType()).getSignedness();
1143 const auto outTypeSignedness =
1144 cast<IntegerType>(outType.getElementType()).getSignedness();
1145
1146 return (inTypeSignedness != outTypeSignedness ||
1147 inType.getElementTypeBitWidth() >
1148 outType.getElementTypeBitWidth());
1149 }
1150
1151 if (inType.getElementType().isFloat() &&
1152 outType.getElementType().isFloat()) {
1153
1154 FloatType inElemTy = cast<FloatType>(inType.getElementType());
1155 FloatType outElemTy = cast<FloatType>(outType.getElementType());
1156 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
1157 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
1158
1159 // If the list of supported types needs to be updated in the future, the
1160 // check down below will need to be revised, for example to account for
1161 // unsigned floating point types, or types that use negative zero as the
1162 // representation for NaN.
1163 [[maybe_unused]] const auto isSupported = [](Type elemType) {
1164 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1165 Float16Type, Float32Type>(elemType);
1166 };
1167
1168 assert(isSupported(inElemTy) &&
1169 "unsupported input element type in isNarrowingCast");
1170 assert(isSupported(outElemTy) &&
1171 "unsupported output element type in isNarrowingCast");
1172
1173 return (
1174 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1175 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1176 inTypeSemantics.precision > outTypeSemantics.precision ||
1177 (supportsNaN(inTypeSemantics) && !supportsNaN(outTypeSemantics)) ||
1178 (supportsInf(inTypeSemantics) && !supportsInf(outTypeSemantics)));
1179 }
1180
1181 // While some cases of int -> float casts can be non-narrowing, consider
1182 // them narrowing for the purposes of this optimization
1183 return true;
1184 }
1185};
1186
1187void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1188 MLIRContext *context) {
1189 results.add<NonNarrowingCastsOptimization>(context);
1190}
1191
1193 : public OpRewritePattern<tosa::CastToBlockScaledOp> {
1194 using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
1195
1196 LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
1197 PatternRewriter &rewriter) const override {
1198 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1199 auto castFromBlockScaledOp =
1200 castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
1201 if (!castFromBlockScaledOp)
1202 return rewriter.notifyMatchFailure(
1203 castToBlockScaledOp,
1204 "input must be cast_from_block_scaled operation");
1205
1206 const Value innerData = castFromBlockScaledOp.getInputData();
1207 const Value innerScale = castFromBlockScaledOp.getInputScale();
1208 const auto innerDataTy = llvm::cast<ShapedType>(innerData.getType());
1209 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.getType());
1210
1211 const Value outerData = castToBlockScaledOp.getOutputData();
1212 const Value outerScale = castToBlockScaledOp.getOutputScale();
1213 const auto outerDataTy = llvm::cast<ShapedType>(outerData.getType());
1214 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.getType());
1215
1216 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1217 return rewriter.notifyMatchFailure(
1218 castToBlockScaledOp,
1219 "inputs types to cast_from_block_scaled operation must match output "
1220 "types to cast_to_block_scaled");
1221 }
1222
1223 if (castFromBlockScaledOp.getBlockSize() !=
1224 castToBlockScaledOp.getBlockSize()) {
1225 return rewriter.notifyMatchFailure(
1226 castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
1227 "cast_to_block_scaled must match");
1228 }
1229
1230 rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});
1231
1232 return success();
1233 }
1234};
1235
1236void CastToBlockScaledOp::getCanonicalizationPatterns(
1237 RewritePatternSet &results, MLIRContext *context) {
1238 results.add<CancellingBlockScaledCastsOptimization>(context);
1239}
1240
1241//===----------------------------------------------------------------------===//
1242// Operator Folders.
1243//===----------------------------------------------------------------------===//
1244
1245template <typename Folder>
1246static DenseElementsAttr
1248 bool foldDenseValues = false) {
1249 if (!lhs || !rhs)
1250 return {};
1251
1252 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1253 return {};
1254
1255 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
1256 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
1257 if (lETy != rETy)
1258 return {};
1259
1260 if (lhs.isSplat() && rhs.isSplat()) {
1261 if (isa<FloatType>(lETy)) {
1262 const APFloat l = lhs.getSplatValue<APFloat>();
1263 const APFloat r = rhs.getSplatValue<APFloat>();
1264 const auto maybeResult = Folder::fold(l, r);
1265 if (failed(maybeResult))
1266 return {};
1267 return DenseElementsAttr::get(returnTy, maybeResult.value());
1268 }
1269
1270 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1271 const APInt l = lhs.getSplatValue<APInt>();
1272 const APInt r = rhs.getSplatValue<APInt>();
1273 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1274 if (failed(maybeResult))
1275 return {};
1276 return DenseElementsAttr::get(returnTy, maybeResult.value());
1277 }
1278 }
1279
1280 if (foldDenseValues) {
1281 assert(lETy.isIntOrIndex() &&
1282 "Only integer types are currently supported.");
1283 SmallVector<APInt> resultValues;
1284 for (auto [l, r] :
1285 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
1286 const auto maybeResult = Folder::fold(l, r, false);
1287 if (failed(maybeResult))
1288 return {};
1289 resultValues.push_back(maybeResult.value());
1290 }
1291 return DenseElementsAttr::get(returnTy, resultValues);
1292 }
1293
1294 return {};
1295}
1296
1297template <typename Folder>
1298static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
1299 bool foldDenseValues = false) {
1300 if (!val)
1301 return {};
1302
1303 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1304 return {};
1305
1306 const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
1307
1308 if (val.isSplat()) {
1309 if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1310 const APInt v = val.getSplatValue<APInt>();
1311 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1312 if (failed(maybeResult))
1313 return {};
1314 return DenseElementsAttr::get(returnTy, maybeResult.value());
1315 }
1316 }
1317
1318 if (foldDenseValues) {
1319 mlir::Type elemTy = val.getElementType();
1320 if (elemTy.isIntOrIndex()) {
1321 SmallVector<APInt> resultValues;
1322 for (auto const &v : val.getValues<APInt>()) {
1323 const auto maybeResult = Folder::fold(v, false);
1324 if (failed(maybeResult))
1325 return {};
1326 resultValues.push_back(maybeResult.value());
1327 }
1328 return DenseElementsAttr::get(returnTy, resultValues);
1329 }
1330 }
1331
1332 // Folding arbitrarily sized tensor operations is not supported
1333 return {};
1334}
1335
1336static FailureOr<int64_t> getSingleI64From1ElementTensor(Value v) {
1337 DenseIntElementsAttr dense{};
1338 if (!matchPattern(v, m_Constant(&dense)))
1339 return failure();
1340
1341 assert(dense.isSplat());
1342 APInt a = dense.getSplatValue<APInt>();
1343 return a.getSExtValue();
1344}
1345
1347 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1348 const bool isUnsigned) {
1349 bool overflow;
1350 const APInt result =
1351 isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
1352 if (overflow)
1353 return failure();
1354 return result;
1355 }
1356
1357 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1358 return lhs + rhs;
1359 }
1360};
1361
1363 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1364 const bool isUnsigned) {
1365 bool overflow;
1366 const APInt result =
1367 isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
1368 if (overflow)
1369 return failure();
1370 return result;
1371 }
1372
1373 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1374 return lhs - rhs;
1375 }
1376};
1377
1379 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1380 const bool isUnsigned) {
1381
1382 const unsigned originalWidth = lhs.getBitWidth();
1383
1384 // Check same type
1385 if (lhs.getBitWidth() != rhs.getBitWidth()) {
1386 return failure();
1387 }
1388
1389 // If either is `0`
1390 if (lhs == 0 || rhs == 0)
1391 return APInt::getZero(originalWidth);
1392
1393 bool overflow = false;
1394 APInt const result =
1395 isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);
1396
1397 if (overflow)
1398 return failure();
1399
1400 return result.trunc(originalWidth);
1401 }
1402
1403 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1404 return lhs * rhs;
1405 }
1406};
1407
1408static bool signsDiffer(const APInt &a, const APInt &b) {
1409 return a.isNegative() != b.isNegative();
1410}
1411
1412template <bool Ceil>
1414 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1415 bool isUnsigned) {
1416 if (lhs.getBitWidth() != rhs.getBitWidth())
1417 return failure();
1418 if (rhs.isZero())
1419 return failure();
1420
1421 if (isUnsigned) {
1422 APInt q{};
1423 APInt r{};
1424 APInt::udivrem(lhs, rhs, q, r);
1425 if (!r.isZero() && Ceil) {
1426 return q + 1;
1427 }
1428 return q;
1429 }
1430
1431 // Signed: start from trunc-toward-zero, then adjust to ceil.
1432 bool overflow{false};
1433 APInt const q = lhs.sdiv_ov(rhs, overflow);
1434 if (overflow)
1435 return failure();
1436 APInt const r = lhs.srem(rhs);
1437
1438 if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) {
1439 // Same sign => exact quotient is positive; trunc is below ceil =>
1440 // increment q.
1441 return q + 1;
1442 }
1443 return q;
1444 }
1445
1446 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1447 return lhs / rhs;
1448 }
1449};
1450
1452 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1453 bool isUnsigned) {
1454 if (lhs.getBitWidth() != rhs.getBitWidth())
1455 return failure();
1456 if (lhs.isNegative() || (!rhs.isStrictlyPositive()))
1457 return failure();
1458
1459 if (isUnsigned) {
1460 return lhs.urem(rhs);
1461 }
1462
1463 return lhs.srem(rhs);
1464 }
1465
1466 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1467 auto t = lhs;
1468 auto const r = t.mod(rhs);
1469 if (llvm::APFloatBase::opStatus::opOK == r) {
1470 return t;
1471 }
1472 return failure();
1473 }
1474};
1475
1477 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1478 bool isUnsigned) {
1479 if (lhs.getBitWidth() != rhs.getBitWidth())
1480 return failure();
1481 return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;
1482 }
1483
1484 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1485 return lhs >= rhs ? lhs : rhs;
1486 }
1487};
1488
1490 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1491 bool isUnsigned) {
1492 if (lhs.getBitWidth() != rhs.getBitWidth())
1493 return failure();
1494 return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;
1495 }
1496
1497 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1498 return lhs <= rhs ? lhs : rhs;
1499 }
1500};
1501
1503 static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
1504 auto const numBits = value.getBitWidth();
1505 if (isUnsigned) {
1506 auto const zextv = value.getZExtValue();
1507 if (zextv >= numBits)
1508 return failure();
1509 return APInt::getOneBitSet(numBits, zextv);
1510 }
1511 auto const sextv = value.getSExtValue();
1512 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1513 return failure();
1514 return APInt::getOneBitSet(numBits, sextv);
1515 }
1516};
1517
1519 static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
1520 if (!value.isStrictlyPositive())
1521 return failure();
1522 return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
1523 }
1524};
1525
1527 static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
1528 if (!value.isStrictlyPositive())
1529 return failure();
1530 return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
1531 }
1532};
1533
1535 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1536 const bool isUnsigned) {
1537 return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
1538 }
1539
1540 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1541 return APInt(1, lhs > rhs);
1542 }
1543};
1544
1546 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1547 const bool isUnsigned) {
1548 return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
1549 }
1550
1551 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1552 return APInt(1, lhs >= rhs);
1553 }
1554};
1555
1557 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1558 const bool isUnsigned) {
1559 return APInt(1, lhs == rhs);
1560 }
1561
1562 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1563 return APInt(1, lhs == rhs);
1564 }
1565};
1566
1567static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1568 if (llvm::isa<FloatType>(elemType))
1569 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1570 if (llvm::isa<IntegerType>(elemType))
1571 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1572 return false;
1573}
1574
1575static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1576 if (llvm::isa<FloatType>(elemType))
1577 return val && val.isSplat() &&
1578 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1579 if (llvm::isa<IntegerType>(elemType)) {
1580 const int64_t shifted = 1LL << shift;
1581 return val && val.isSplat() &&
1582 val.getSplatValue<APInt>().getSExtValue() == shifted;
1583 }
1584 return false;
1585}
1586
1587OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1588 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1589 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1590 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1591 if (!lhsTy || !rhsTy || !resultTy)
1592 return {};
1593
1594 // Cannot create an ElementsAttr from non-int/float/index types
1595 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1596 !rhsTy.getElementType().isIntOrIndexOrFloat())
1597 return {};
1598
1599 auto resultETy = resultTy.getElementType();
1600 auto lhsAttr =
1601 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1602 auto rhsAttr =
1603 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1604
1605 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1606 lhsTy.getShape(), rhsTy.getShape());
1607 if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1608 return getInput1();
1609 if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
1610 return getInput2();
1611
1612 if (!lhsAttr || !rhsAttr)
1613 return {};
1614
1615 return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1616}
1617
1618OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1619 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1620 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1621 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1622 !outputTy.hasStaticShape())
1623 return {};
1624
1625 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1626 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1627 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1628 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1629 return DenseElementsAttr::get(outputTy, zero);
1630 }
1631
1632 return {};
1633}
1634
1635OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1636 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1637 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1638 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1639 if (!lhsTy || !rhsTy || !resultTy)
1640 return {};
1641 if (lhsTy.getElementType() != rhsTy.getElementType())
1642 return {};
1643
1644 // IntDivOp inputs must be integer type, no need to check for quantized
1645 // type
1646 auto resultETy = resultTy.getElementType();
1647 auto lhsAttr =
1648 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1649 auto rhsAttr =
1650 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1651 if (lhsAttr && lhsAttr.isSplat()) {
1652 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1653 lhsAttr.getSplatValue<APInt>().isZero())
1654 return lhsAttr.resizeSplat(resultTy);
1655 }
1656
1657 if (rhsAttr && rhsAttr.isSplat()) {
1658 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1659 lhsTy.getShape(), rhsTy.getShape());
1660 if (isBroadcastable && lhsTy == resultTy &&
1661 llvm::isa<IntegerType>(resultETy) &&
1662 rhsAttr.getSplatValue<APInt>().isOne())
1663 return getInput1();
1664 }
1665
1666 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1667 llvm::isa<IntegerType>(resultETy)) {
1668 APInt l = lhsAttr.getSplatValue<APInt>();
1669 APInt r = rhsAttr.getSplatValue<APInt>();
1670 if (!r.isZero()) {
1671 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1672 auto const result =
1673 DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
1674 if (failed(result))
1675 return {};
1676 return DenseElementsAttr::get(resultTy, result.value());
1677 }
1678 }
1679
1680 return {};
1681}
1682
1683namespace {
1684// calculate lhs * rhs >> shift according to TOSA Spec
1685// return nullopt if result is not in range of int32_t when shift > 0
1686std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1687 unsigned bitwidth) {
1688 bool overflow = false;
1689 APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);
1690
1691 if (overflow)
1692 return std::nullopt;
1693
1694 if (shift > 0) {
1695 auto round = APInt(64, 1) << (shift - 1);
1696 result += round;
1697 result.ashrInPlace(shift);
1698 // REQUIRE(product >= minimum_s<i32_t>() && product <=
1699 // maximum_s<i32_t>())
1700 if (!(result.getSExtValue() >= INT32_MIN &&
1701 result.getSExtValue() <= INT32_MAX)) {
1702 // REQUIRE failed
1703 return std::nullopt;
1704 }
1705 }
1706
1707 return result.trunc(bitwidth);
1708}
1709
1710DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1711 RankedTensorType ty, int32_t shift) {
1712 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1713 if (llvm::isa<IntegerType>(ty.getElementType())) {
1714 APInt l = lhs.getSplatValue<APInt>();
1715 APInt r = rhs.getSplatValue<APInt>();
1716
1717 if (shift == 0) {
1718 return DenseElementsAttr::get(ty, l * r);
1719 }
1720
1721 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1722 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1723 if (!result)
1724 return {};
1725 return DenseElementsAttr::get(ty, result.value());
1726 }
1727
1728 if (llvm::isa<FloatType>(ty.getElementType())) {
1729 APFloat l = lhs.getSplatValue<APFloat>();
1730 APFloat r = rhs.getSplatValue<APFloat>();
1731 APFloat result = l * r;
1732 return DenseElementsAttr::get(ty, result);
1733 }
1734 }
1735
1736 return {};
1737}
1738} // namespace
1739
1740OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1741 auto lhs = getInput1();
1742 auto rhs = getInput2();
1743 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1744 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1745 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1746 if (!lhsTy || !rhsTy || !resultTy)
1747 return {};
1748
1749 auto resultETy = resultTy.getElementType();
1750 auto lhsAttr =
1751 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1752 auto rhsAttr =
1753 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1754
1755 // Result right shift on i32_t data type only. For simplification,
1756 // synthesize a zero shift for other data type.
1757 int32_t shift = 0;
1758 if (resultETy.isInteger(32)) {
1759 ElementsAttr shift_elem;
1760 if (getShift().getImpl()) {
1761 if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1762 // cannot be folded when the shift value is unknown.
1763 return {};
1764 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1765 }
1766 }
1767
1768 if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
1769 resultTy.hasStaticShape())
1770 // constant values can only be resized if resulting type is static
1771 return lhsAttr.resizeSplat(resultTy);
1772 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
1773 resultTy.hasStaticShape())
1774 return rhsAttr.resizeSplat(resultTy);
1775
1776 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1777 lhsTy.getShape(), rhsTy.getShape());
1778 if (isBroadcastable && rhsTy == resultTy &&
1779 isSplatOne(resultETy, lhsAttr, shift))
1780 return rhs;
1781 if (isBroadcastable && lhsTy == resultTy &&
1782 isSplatOne(resultETy, rhsAttr, shift))
1783 return lhs;
1784
1785 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1786}
1787
1788OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1789 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1790 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1791 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1792 if (!lhsTy || !rhsTy || !resultTy)
1793 return {};
1794
1795 // Cannot create an ElementsAttr from non-int/float/index types
1796 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1797 !rhsTy.getElementType().isIntOrIndexOrFloat())
1798 return {};
1799
1800 auto resultETy = resultTy.getElementType();
1801 auto lhsAttr =
1802 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1803 auto rhsAttr =
1804 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1805
1806 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1807 lhsTy.getShape(), rhsTy.getShape());
1808 if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1809 return getInput1();
1810
1811 if (!lhsAttr || !rhsAttr)
1812 return {};
1813
1814 return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1815}
1816
1817OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1818 auto resultTy = llvm::cast<ShapedType>(getType());
1819 auto lhsAttr =
1820 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1821 auto rhsAttr =
1822 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1823
1824 if (!lhsAttr || !rhsAttr)
1825 return {};
1826
1827 return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1828}
1829
1830OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1831 auto resultTy = llvm::cast<ShapedType>(getType());
1832 auto lhsAttr =
1833 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1834 auto rhsAttr =
1835 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1836
1837 if (!lhsAttr || !rhsAttr)
1838 return {};
1839
1840 return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1841}
1842
1843OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1844 auto resultTy = llvm::cast<ShapedType>(getType());
1845 auto lhsAttr =
1846 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1847 auto rhsAttr =
1848 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1849 Value lhs = getInput1();
1850 Value rhs = getInput2();
1851 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1852
1853 // If we are comparing an integer value to itself it is always true. We
1854 // can not do this with float due to float values.
1855 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1856 resultTy.hasStaticShape() && lhs == rhs) {
1857 return DenseElementsAttr::get(resultTy, true);
1858 }
1859
1860 if (!lhsAttr || !rhsAttr)
1861 return {};
1862
1863 return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1864}
1865
1866OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1867 if (getInput().getType() == getType())
1868 return getInput();
1869
1870 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1871 if (!operand)
1872 return {};
1873
1874 auto inTy = llvm::cast<ShapedType>(getInput().getType());
1875 auto outTy = llvm::cast<ShapedType>(getType());
1876 if (!outTy.hasRank() || !outTy.hasStaticShape())
1877 return {};
1878 auto inETy = inTy.getElementType();
1879 auto outETy = outTy.getElementType();
1880
1881 if (operand.isSplat()) {
1882 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1883 bool overflow;
1884 auto splatVal = operand.getSplatValue<APFloat>();
1885 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1886 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1887 &overflow);
1888 return SplatElementsAttr::get(outTy, splatVal);
1889 }
1890
1891 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1892 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1893 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1894 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1895 llvm::RoundingMode::NearestTiesToEven);
1896 return SplatElementsAttr::get(outTy, splatVal);
1897 }
1898
1899 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1900 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1901 auto intVal = APSInt(
1902 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1903 auto floatVal = operand.getSplatValue<APFloat>();
1904 bool exact;
1905 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1906 &exact);
1907 return SplatElementsAttr::get(outTy, intVal);
1908 }
1909
1910 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1911 const auto inIntType = llvm::cast<IntegerType>(inETy);
1912 auto unsignIn = inIntType.isUnsignedInteger();
1913 bool trunc =
1914 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1915 auto intVal = operand.getSplatValue<APInt>();
1916 auto bitwidth = outETy.getIntOrFloatBitWidth();
1917
1918 // i1 types are boolean in TOSA
1919 if (outETy.isInteger(1)) {
1920 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1921 } else if (trunc) {
1922 intVal = intVal.trunc(bitwidth);
1923 } else if (unsignIn || inIntType.isInteger(1)) {
1924 intVal = intVal.zext(bitwidth);
1925 } else {
1926 intVal = intVal.sext(bitwidth);
1927 }
1928
1929 return SplatElementsAttr::get(outTy, intVal);
1930 }
1931 }
1932
1933 return {};
1934}
1935
1936OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1937
1938OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1939
1940#define REDUCE_FOLDER(OP) \
1941 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1942 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1943 if (!inputTy.hasRank()) \
1944 return {}; \
1945 if (inputTy != getType()) \
1946 return {}; \
1947 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1948 return getInput(); \
1949 return {}; \
1950 }
1951
1952REDUCE_FOLDER(ReduceAllOp)
1953REDUCE_FOLDER(ReduceAnyOp)
1954REDUCE_FOLDER(ReduceMaxOp)
1955REDUCE_FOLDER(ReduceMinOp)
1956REDUCE_FOLDER(ReduceProductOp)
1957REDUCE_FOLDER(ReduceSumOp)
1958#undef REDUCE_FOLDER
1959
1960OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1961 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1962 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1963
1964 if (!inputTy || !outputTy)
1965 return {};
1966
1967 // Fold when the input and output types are the same. This is only safe
1968 // when there is at most 1 dynamic dimension. For 2 or more dynamic
1969 // dimensions, there may still be a productive reshape.
1970 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1971 return getInput1();
1972
1973 // reshape(reshape(x)) -> reshape(x)
1974 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1975 getInput1().getDefiningOp())) {
1976 getInput1Mutable().assign(reshapeOp.getInput1());
1977 return getResult();
1978 }
1979
1980 // Cannot create an ElementsAttr from non-int/float/index types
1981 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1982 return {};
1983
1984 // reshape(const(x)) -> const(reshape-attr(x))
1985 if (auto operand =
1986 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1987 // Constants must have static shape.
1988 if (!outputTy.hasStaticShape())
1989 return {};
1990
1991 // Okay to duplicate splat constants.
1992 if (operand.isSplat())
1993 return SplatElementsAttr::get(outputTy,
1994 operand.getSplatValue<Attribute>());
1995
1996 // Don't duplicate other constants.
1997 if (!getInput1().hasOneUse())
1998 return {};
1999
2001 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
2002 return {};
2003
2004 return operand.reshape(
2005 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
2006 }
2007
2008 return {};
2009}
2010
2011OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
2012 // If the pad is all zeros we can fold this operation away.
2013 if (adaptor.getPadding() && getInput1().getType() == getType()) {
2014 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
2015 if (densePad && densePad.isSplat() &&
2016 densePad.getSplatValue<APInt>().isZero()) {
2017 return getInput1();
2018 }
2019 }
2020
2021 return {};
2022}
2023
2024// Fold away cases where a tosa.resize operation returns a copy
2025// of the input image.
2026OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
2027 auto scaleAttr =
2028 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
2029 auto offsetAttr =
2030 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
2031 auto borderAttr =
2032 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
2033 if (!scaleAttr || !offsetAttr || !borderAttr) {
2034 return {};
2035 }
2036
2037 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
2038 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
2039 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
2040 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
2041 return {};
2042 }
2043
2044 // Check unit scaling.
2045 if (scale[0] != scale[1] || scale[2] != scale[3]) {
2046 return {};
2047 }
2048
2049 // There should be no offset.
2050 if (offset[0] != 0 || offset[1] != 0) {
2051 return {};
2052 }
2053
2054 // There should be no border.
2055 if (border[0] != 0 || border[1] != 0) {
2056 return {};
2057 }
2058
2059 return foldToInputIfTypeMatches(getType(), getInput());
2060}
2061
2062OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
2063 auto operand = getInput1();
2064 auto operandTy = llvm::cast<ShapedType>(operand.getType());
2065 auto axis = getAxis();
2066 // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
2067 const bool isSplatInput =
2068 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
2069 if (!operandTy.hasRank() ||
2070 (!isSplatInput && operandTy.getDimSize(axis) != 1))
2071 return {};
2072 return foldToInputIfTypeMatches(getType(), operand);
2073}
2074
2075OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
2076 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2077 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
2078
2079 if (!inputTy || !outputTy)
2080 return {};
2081
2082 if (inputTy == outputTy && inputTy.hasStaticShape())
2083 return getInput1();
2084
2085 // Check if this is a no-op slice (starts at 0 and size matches input)
2086
2087 DenseElementsAttr startElems;
2088 if (!matchPattern(getStart(), m_Constant(&startElems)))
2089 return {};
2090
2091 // Check if all start values are zero
2092 bool startIsZeros =
2093 llvm::all_of(startElems.getValues<APInt>(),
2094 [](const APInt &val) { return val.isZero(); });
2095
2096 if (startIsZeros) {
2097
2098 // Check if size matches input shape
2099 DenseElementsAttr sizeElems;
2100 if (!matchPattern(getSize(), m_Constant(&sizeElems)))
2101 return {};
2102
2103 auto inputShape = inputTy.getShape();
2104 auto sizeValues = sizeElems.getValues<APInt>();
2105
2106 bool sizeMatchesInput = true;
2107 for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
2108 int64_t size = sizeVal.getSExtValue();
2109
2110 if (inputTy.isDynamicDim(i)) {
2111 // For dynamic dimensions, check for kInferableDimSize indicating full
2112 // dimension is sliced
2113 if (size != kInferableDimSize) {
2114 sizeMatchesInput = false;
2115 break;
2116 }
2117 } else {
2118 // For static dimensions, check that size must match exactly or be
2119 // kInferableDimSize indicating full dimension is sliced
2120 if (size != kInferableDimSize && size != inputShape[i]) {
2121 sizeMatchesInput = false;
2122 break;
2123 }
2124 }
2125 }
2126
2127 if (sizeMatchesInput)
2128 return getInput1();
2129 }
2130
2131 // The following checks require the input to be a constant
2132 if (!adaptor.getInput1())
2133 return {};
2134
2135 // Cannot create an ElementsAttr from non-int/float/index types
2136 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
2137 !outputTy.getElementType().isIntOrIndexOrFloat())
2138 return {};
2139
2140 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
2141 if (operand.isSplat() && outputTy.hasStaticShape()) {
2142 return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
2143 }
2144
2145 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
2146 outputTy.getNumElements() == 1) {
2147 llvm::SmallVector<uint64_t> indices =
2148 llvm::to_vector(startElems.getValues<uint64_t>());
2149 if (auto values = operand.tryGetValues<Attribute>())
2150 return SplatElementsAttr::get(outputTy, (*values)[indices]);
2151 }
2152
2153 return {};
2154}
2155
2156OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
2157 const Value pred = getPred();
2158 const Value onTrue = getOnTrue();
2159 const Value onFalse = getOnFalse();
2160
2161 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
2162 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
2163 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
2164 if (!predTy || !onTrueTy || !onFalseTy)
2165 return {};
2166
2167 const Type resultTy = getType();
2168
2169 const ArrayRef<int64_t> predShape = predTy.getShape();
2170 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2171
2172 if (onTrue == onFalse && onTrueTy == resultTy &&
2173 OpTrait::util::staticallyKnownBroadcastable(predShape, onTrueShape))
2174 return onTrue;
2175
2176 auto predicate =
2177 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2178 if (!predicate)
2179 return {};
2180 if (!predicate.isSplat())
2181 return {};
2182
2183 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2184
2185 SmallVector<SmallVector<int64_t>, 3> shapes;
2186 shapes.emplace_back(predShape);
2187 shapes.emplace_back(onTrueShape);
2188 shapes.emplace_back(onFalseTy.getShape());
2189 const bool isBroadcastable =
2191
2192 if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
2193 return onTrue;
2194 if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)
2195 return onFalse;
2196 return {};
2197}
2198
2199OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2200 if (getInput1().getType() == getType()) {
2201 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2202 adaptor.getMultiples())) {
2203 if (multiples.isSplat() &&
2204 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2205 return getInput1();
2206 if (auto int_array_attr =
2207 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2208 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2209 [](APInt v) { return v.getSExtValue() == 1; }))
2210 return getInput1();
2211 }
2212 }
2213 }
2214 return {};
2215}
2216
2217OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2218 auto resultTy = llvm::cast<ShapedType>(getType());
2219
2220 // Transposing splat values just means reshaping.
2221 if (auto input =
2222 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2223 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2224 input.getType().getElementType() == resultTy.getElementType())
2225 return input.reshape(resultTy);
2226 }
2227
2228 // Transpose is not the identity transpose.
2229 const llvm::ArrayRef<int32_t> perms = getPerms();
2230
2231 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2232 return {};
2233
2234 return foldToInputIfTypeMatches(getType(), getInput1());
2235}
2236
2237OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2238 // Element-wise negate(negate(x)) = x
2239 // iff all zero points are constant 0
2240 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2241 if (!definingOp) {
2242 // defining op of input1 is not a negate, cannot fold
2243 return {};
2244 }
2245
2246 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2247 failed(maybeIZp) || *maybeIZp != 0) {
2248 // input1 zero point is not constant 0, cannot fold
2249 return {};
2250 }
2251 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2252 failed(maybeOZp) || *maybeOZp != 0) {
2253 // output zero point is not constant 0, cannot fold
2254 return {};
2255 }
2256 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2257 failed(maybeIZp) || *maybeIZp != 0) {
2258 // definingOp's input1 zero point is not constant 0, cannot fold
2259 return {};
2260 }
2261 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2262 failed(maybeOZp) || *maybeOZp != 0) {
2263 // definingOp's output zero point is not constant 0, cannot fold
2264 return {};
2265 }
2266
2267 return foldToInputIfTypeMatches(getType(), definingOp.getInput1());
2268}
2269
2270OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2271 auto input = getInput1();
2272 // Element-wise abs(abs(x)) = abs(x)
2273 if (input.getDefiningOp<tosa::AbsOp>())
2274 return foldToInputIfTypeMatches(getType(), input);
2275
2276 return {};
2277}
2278
2279OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2280 auto input = adaptor.getInput1();
2281
2282 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2283 // Fold splat inputs only.
2284 if (!inputAttr || !inputAttr.isSplat())
2285 return {};
2286
2287 auto shapeType = llvm::cast<ShapedType>(getType());
2288 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2289 return {};
2290 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2291 auto floatVal = inputAttr.getSplatValue<APFloat>();
2292 return DenseElementsAttr::get(shapeType,
2293 ReciprocalOp::calcOneElement(floatVal));
2294 }
2295
2296 return {};
2297}
2298
2299template <typename Op, typename OpFoldAdaptor>
2301 auto input1ConstShape =
2302 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2303 if (!input1ConstShape)
2304 return {};
2305
2306 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2307
2308 return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
2309 /*foldDenseValues=*/true);
2310}
2311
2312template <typename Op, typename OpFoldAdaptor>
2314 auto input1ConstShape =
2315 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2316 auto input2ConstShape =
2317 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2318 if (!input1ConstShape || !input2ConstShape)
2319 return {};
2320
2321 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2322 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2323
2324 return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
2325 input1Attr.getType(),
2326 /*foldDenseValues=*/true);
2327}
2328
2329OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2330 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
2331 if (!inputTy || !inputTy.hasRank())
2332 return {};
2333 const int32_t axis = getAxis();
2334 const int64_t dimSize = inputTy.getDimSize(axis);
2335 if (ShapedType::isDynamic(dimSize))
2336 return {};
2337
2338 OpBuilder builder(getContext());
2339 const auto resultAttrTy =
2340 RankedTensorType::get(/*rank=*/1, builder.getIndexType());
2341 return DenseElementsAttr::get(resultAttrTy, dimSize);
2342}
2343
2344OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
2345 auto const inputs = op->getInput();
2346
2347 if (inputs.empty())
2348 return {};
2349
2350 SmallVector<APInt> concatDims;
2351 concatDims.reserve(/*max elem*/ 64);
2352 for (auto const &v : inputs) {
2353 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2354 if (!vConstShape)
2355 return {};
2356
2357 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2358 assert(vAttr);
2359
2360 auto const vAttrVals = vAttr.getValues<APInt>();
2361 for (auto const &v : vAttrVals) {
2362 concatDims.push_back(v);
2363 }
2364 }
2365
2366 auto *ctx = op->getContext();
2367 assert(ctx != nullptr && "ctx is nullptr");
2368 auto const rankedTy = RankedTensorType::get(
2369 {static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2370
2371 return DenseElementsAttr::get(rankedTy, concatDims);
2372}
2373
2374OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op) {
2375 auto const input1 = op->getInput();
2376 auto const input2 = op->getStart();
2377 auto const input3 = op->getSize();
2378
2379 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2380
2381 if (!input1ConstShape)
2382 return {};
2383
2384 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2385 if (!input1Attr)
2386 return {};
2387
2388 auto const input1Vals = input1Attr.getValues<APInt>();
2389 auto const totalInput1 = input1Vals.size();
2390
2391 auto const start = getSingleI64From1ElementTensor(input2);
2392 auto const size = getSingleI64From1ElementTensor(input3);
2393
2394 if (failed(start) || failed(size))
2395 return {};
2396
2397 auto const startV = static_cast<int32_t>(start.value());
2398 auto const sizeV = static_cast<int32_t>(size.value());
2399
2400 if ((sizeV <= 0) || (startV < 0) ||
2401 (static_cast<size_t>(startV + sizeV) > totalInput1))
2402 return {};
2403
2404 SmallVector<APInt> sliceOfInput;
2405 sliceOfInput.reserve(totalInput1);
2406
2407 for (auto i = startV; i < (startV + sizeV); i++) {
2408 sliceOfInput.push_back(input1Vals[i]);
2409 }
2410
2411 auto *ctx = op->getContext();
2412 assert(ctx != nullptr && "ctx is nullptr");
2413
2414 auto const rankedTy = RankedTensorType::get(
2415 {static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2416
2417 return DenseElementsAttr::get(rankedTy, sliceOfInput);
2418}
2419
2420OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2422}
2423
2424OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2426}
2427
2428OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2430}
2431
2432OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2433 return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
2434}
2435
2436OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2437 return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
2438}
2439
2440OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2442}
2443
2444OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2446}
2447
2448OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2450}
2451
2452OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2454}
2455
2456OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2458}
2459
2460OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2462}
2463
2464OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
2465 return concatShapeFold(this);
2466}
2467
2468OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
2469 return sliceShapeFold(this);
2470}
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:259
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
TosaLevel getTosaLevelFromEnum(const Level level)
Definition TargetEnv.cpp:15
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:106
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
TargetEnvAttr lookupTargetEnv(Operation *op)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
int32_t MAX_TENSOR_LIST_SIZE
Definition TargetEnv.h:30