MLIR 22.0.0git
TensorOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
18#include "mlir/IR/Builders.h"
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/Matchers.h"
32#include "mlir/Support/LLVM.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/SmallBitVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/MathExtras.h"
39#include <optional>
40
41using namespace mlir;
42using namespace mlir::tensor;
43
44/// Materialize a single constant operation from a given attribute value with
45/// the desired resultant type.
46Operation *TensorDialect::materializeConstant(OpBuilder &builder,
47 Attribute value, Type type,
48 Location loc) {
49 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
50 return op;
51 if (complex::ConstantOp::isBuildableWith(value, type))
52 return complex::ConstantOp::create(builder, loc, type,
53 llvm::cast<ArrayAttr>(value));
54 return nullptr;
55}
56
57OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
 58 int64_t dim) {
59 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
60 if (tensorType.isDynamicDim(dim))
61 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
62
63 return builder.getIndexAttr(tensorType.getDimSize(dim));
64}
65
66SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
 67 Location loc, Value value) {
68 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
 69 SmallVector<OpFoldResult> result;
 70 for (int64_t i = 0; i < tensorType.getRank(); ++i)
71 result.push_back(getMixedSize(builder, loc, value, i));
72 return result;
73}
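// Illustrative note (editorial, not in the original source): for a value of
// type tensor<4x?xf32>, getMixedSizes returns {IndexAttr(4), %d}, where %d is
// the (possibly folded) result of a tensor.dim on the dynamic dimension.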
74
75FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
 76 OpResult opResult) {
77 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
78 assert(tensorType && "expected tensor type");
79
80 // If the op has a destination, it implements DestinationStyleOpInterface and
81 // we can query the destination operand from that interface.
82 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
83 if (destOp)
84 return destOp.getTiedOpOperand(opResult)->get();
85
86 // Otherwise, create a new destination tensor with the same shape.
 87 OpBuilder::InsertionGuard g(b);
 88 b.setInsertionPoint(opResult.getDefiningOp());
89
90 // Compute sizes.
 91 SmallVector<OpFoldResult> mixedSizes;
 92 if (!tensorType.hasStaticShape()) {
93 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
94 ReifiedRankedShapedTypeDims reifiedShapes;
95 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
96 return failure();
97 mixedSizes = reifiedShapes[opResult.getResultNumber()];
98 } else {
99 // Static shape: Take static sizes directly.
100 for (int64_t sz : tensorType.getShape())
101 mixedSizes.push_back(b.getIndexAttr(sz));
102 }
103
104 // Create empty tensor.
105 Value emptyTensor =
106 tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
107 return emptyTensor;
108}
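// Illustrative example (editorial): for a destination-style op the tied init
// operand is returned directly; for a non-DPS op producing tensor<?x8xf32>, a
// new tensor.empty(%d) : tensor<?x8xf32> is created, with %d obtained by
// reifying the op's result shape.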
109
110LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
 111 Operation *op,
 112 SmallVector<Value> &result) {
 113 for (OpResult opResult : op->getResults()) {
114 if (llvm::isa<TensorType>(opResult.getType())) {
115 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
116 if (failed(destination))
117 return failure();
118 result.push_back(*destination);
119 }
120 }
121 return success();
122}
123
124bool mlir::tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
 125 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
126 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
127 return rtp1.getShape() == rtp2.getShape() &&
128 rtp1.getElementType() == rtp2.getElementType();
129 return false;
130 }
131 return tp1 == tp2; // default implementation
132}
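// For illustration: tensor<4x8xf32, #enc> and tensor<4x8xf32> compare equal
// here (shape and element type match, encodings are ignored), whereas
// tensor<4x8xf32> and tensor<4x?xf32> do not.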
133
134/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135/// rank-extending tensor.insert_slice op.
136static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
137 ArrayRef<OpFoldResult> mixedSizes) {
138 llvm::SmallBitVector droppedDims(mixedSizes.size());
139 int64_t shapePos = reducedShape.size() - 1;
140
141 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
142 size_t idx = mixedSizes.size() - size.index() - 1;
143 // Rank-reduced dims must have a static unit dimension.
144 bool isStaticUnitSize =
145 isa<Attribute>(size.value()) &&
146 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
147
148 if (shapePos < 0) {
149 // There are no more dims in the reduced shape. All remaining sizes must
150 // be rank-reduced dims.
151 assert(isStaticUnitSize && "expected unit dim");
152 droppedDims.set(idx);
153 continue;
154 }
155
156 // Dim is preserved if the size is not a static 1.
157 if (!isStaticUnitSize) {
158 --shapePos;
159 continue;
160 }
161
162 // Dim is preserved if the reduced shape dim is also 1.
163 if (reducedShape[shapePos] == 1) {
164 --shapePos;
165 continue;
166 }
167
168 // Otherwise: Dim is dropped.
169 droppedDims.set(idx);
170 }
171
172 assert(shapePos < 0 && "dimension mismatch");
173 return droppedDims;
174}
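// Worked example (editorial sketch): reducedShape = [6, 5] with
// mixedSizes = [1, 6, 1, 5] yields droppedDims = {0, 2}; the two static unit
// sizes have no counterpart in the reduced shape and are therefore dropped.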
175
176/// Given a ranked tensor type and a range of values that defines its dynamic
177/// dimension sizes, turn all dynamic sizes that have a constant value into
178/// static dimension sizes.
179static RankedTensorType
180foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
181 SmallVector<Value> &foldedDynamicSizes) {
182 SmallVector<int64_t> staticShape(type.getShape());
183 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
184 "incorrect number of dynamic sizes");
185
186 // Compute new static and dynamic sizes.
187 unsigned ctr = 0;
188 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189 if (type.isDynamicDim(i)) {
190 Value dynamicSize = dynamicSizes[ctr++];
191 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192 if (cst.has_value()) {
193 // Dynamic size must be non-negative.
194 if (cst.value() < 0) {
195 foldedDynamicSizes.push_back(dynamicSize);
196 continue;
197 }
198 staticShape[i] = *cst;
199 } else {
200 foldedDynamicSizes.push_back(dynamicSize);
201 }
202 }
203 }
204
205 return RankedTensorType::get(staticShape, type.getElementType(),
206 type.getEncoding());
207}
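// Illustrative example: for type tensor<?x?xf32> with dynamicSizes = {%c5, %n},
// where %c5 is a constant 5 and %n is unknown, the result is tensor<5x?xf32>
// and foldedDynamicSizes = {%n}.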
208
209//===----------------------------------------------------------------------===//
210// BitcastOp
211//===----------------------------------------------------------------------===//
212
213bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214 if (inputs.size() != 1 || outputs.size() != 1)
215 return false;
216 Type a = inputs.front(), b = outputs.front();
217 auto aT = dyn_cast<TensorType>(a);
218 auto bT = dyn_cast<TensorType>(b);
219 if (!aT || !bT)
220 return false;
221
222 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223 return false;
224
225 return succeeded(verifyCompatibleShape(aT, bT));
226}
227
228namespace {
229
230/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231/// operation.
232struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233 using OpRewritePattern<BitcastOp>::OpRewritePattern;
234
235 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236 PatternRewriter &rewriter) const final {
237 auto tensorBitcastOperand =
238 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239 if (!tensorBitcastOperand)
240 return failure();
241
242 auto resultType = cast<TensorType>(tensorBitcast.getType());
243 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244 tensorBitcastOperand.getOperand());
245 return success();
246 }
247};
248
249} // namespace
250
251void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252 MLIRContext *context) {
253 results.add<ChainedTensorBitcast>(context);
254}
255
256//===----------------------------------------------------------------------===//
257// CastOp
258//===----------------------------------------------------------------------===//
259
260void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261 setNameFn(getResult(), "cast");
262}
263
264/// Returns true if `target` is a ranked tensor type that preserves static
265/// information available in the `source` ranked tensor type.
266bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
 267 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
269
270 // Requires RankedTensorType.
271 if (!sourceType || !targetType)
272 return false;
273
274 // Requires same elemental type.
275 if (sourceType.getElementType() != targetType.getElementType())
276 return false;
277
278 // Requires same rank.
279 if (sourceType.getRank() != targetType.getRank())
280 return false;
281
282 // Requires same encoding.
283 if (sourceType.getEncoding() != targetType.getEncoding())
284 return false;
285
286 // If cast is towards more static sizes along any dimension, don't fold.
287 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288 if (ShapedType::isStatic(std::get<0>(t)) &&
289 ShapedType::isDynamic(std::get<1>(t)))
290 return false;
291 }
292
293 return true;
294}
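// For illustration: preservesStaticInformation(tensor<?x8xf32>, tensor<4x8xf32>)
// is true (every statically known source dim stays static in the target), while
// preservesStaticInformation(tensor<4x8xf32>, tensor<?x8xf32>) is false.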
295
296/// Determines whether tensor::CastOp casts to a more dynamic version of the
297/// source tensor. This is useful to fold a tensor.cast into a consuming op and
298/// implement canonicalization patterns for ops in different dialects that may
299/// consume the results of tensor.cast operations. Such foldable tensor.cast
300/// operations are typically inserted as `slice` ops and are canonicalized
301/// to preserve the type compatibility of their uses.
302///
303/// Returns true when all conditions are met:
304/// 1. source and result are ranked tensors with same element type and rank.
305/// 2. the source type has more static information than the result type.
306///
307/// Example:
308/// ```mlir
309/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310/// %2 = consumer %1 ... : tensor<?x?xf32> ...
311/// ```
312///
313/// folds into:
314///
315/// ```mlir
316/// %2 = consumer %0 ... : tensor<8x16xf32> ...
317/// ```
318bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 319 if (!castOp)
320 return false;
321
322 // Can fold if the source of cast has at least as much static information as
323 // its results.
324 return preservesStaticInformation(castOp.getType(),
325 castOp.getSource().getType());
326}
327
328/// Determines whether the tensor::CastOp casts to a more static version of the
329/// source tensor. This is useful to fold into a producing op and implement
330/// canonicalization patterns with the `tensor.cast` op as the root, while the
331/// producer may come from a different dialect. Returns true when all conditions
332/// are met:
333/// 1. source and result are ranked tensors with the same element type and rank.
334/// 2. the result type has more static information than the source.
335///
336/// Example:
337/// ```mlir
338/// %1 = producer ... : tensor<?x?xf32>
339/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
340/// ```
341///
342/// can be canonicalized to :
343///
344/// ```mlir
345/// %2 = producer ... : tensor<8x16xf32>
346/// ```
347/// Not all ops might be canonicalizable this way, but for those that can be,
348/// this method provides a check that it is worth doing the canonicalization.
349bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
 350 if (!castOp)
351 return false;
352 return preservesStaticInformation(castOp.getSource().getType(),
353 castOp.getType());
354}
355
356bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
 357 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
358 if (llvm::isa<BlockArgument>(opOperand.get()))
359 return false;
360 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
361 return castOp && canFoldIntoConsumerOp(castOp);
362 });
363}
364
365SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
 366 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
367 SmallVector<Value> newOperands;
368 newOperands.reserve(op->getNumOperands());
369
370 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
371
372 // Assumes that the result has dpsInits followed by nonDpsInits.
373 int64_t dpsInitIdx = 0;
374 for (OpOperand &opOperand : op->getOpOperands()) {
375 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
376 bool fold = canFoldIntoConsumerOp(tensorCastOp);
377 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
378 if (op.isDpsInit(&opOperand) &&
379 !llvm::isa<MemRefType>(newOperands.back().getType()))
380 newResTy[dpsInitIdx++] = newOperands.back().getType();
381 }
382 return newOperands;
383}
384
385/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
386/// that can be folded.
387LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
 388 bool folded = false;
389 for (OpOperand &operand : op->getOpOperands()) {
390 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
391 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
392 operand.set(castOp.getOperand());
393 folded = true;
394 }
395 }
396 return success(folded);
397}
398
399bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
400 if (inputs.size() != 1 || outputs.size() != 1)
401 return false;
402 Type a = inputs.front(), b = outputs.front();
403 auto aT = llvm::dyn_cast<TensorType>(a);
404 auto bT = llvm::dyn_cast<TensorType>(b);
405 if (!aT || !bT)
406 return false;
407
408 if (aT.getElementType() != bT.getElementType())
409 return false;
410
411 return succeeded(verifyCompatibleShape(aT, bT));
412}
413
414/// Compute a TensorType that has the joined shape knowledge of the two
415/// given TensorTypes. The element types need to match.
416static TensorType joinShapes(TensorType one, TensorType two) {
 417 assert(one.getElementType() == two.getElementType());
418
419 if (!one.hasRank())
420 return two;
421 if (!two.hasRank())
422 return one;
423
424 int64_t rank = one.getRank();
425 if (rank != two.getRank())
426 return {};
427
 428 SmallVector<int64_t, 4> join;
 429 join.reserve(rank);
430 for (int64_t i = 0; i < rank; ++i) {
431 if (one.isDynamicDim(i)) {
432 join.push_back(two.getDimSize(i));
433 continue;
434 }
435 if (two.isDynamicDim(i)) {
436 join.push_back(one.getDimSize(i));
437 continue;
438 }
439 if (one.getDimSize(i) != two.getDimSize(i))
440 return {};
441 join.push_back(one.getDimSize(i));
442 }
443 return RankedTensorType::get(join, one.getElementType());
444}
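// Illustrative example: joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) is
// tensor<4x8xf32>; joinShapes(tensor<4xf32>, tensor<8xf32>) is the null type
// because the static sizes conflict.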
445
446namespace {
447
448/// Replaces chains of two tensor.cast operations by a single tensor.cast
449/// operation if doing so does not remove runtime constraints.
450struct ChainedTensorCast : public OpRewritePattern<CastOp> {
451 using OpRewritePattern<CastOp>::OpRewritePattern;
452
453 LogicalResult matchAndRewrite(CastOp tensorCast,
454 PatternRewriter &rewriter) const final {
455 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
456
457 if (!tensorCastOperand)
458 return failure();
459
460 auto sourceType =
461 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
462 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
463 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
464
465 // We can remove the intermediate cast if joining all three produces the
466 // same result as just joining the source and result shapes.
467 auto firstJoin =
468 joinShapes(joinShapes(sourceType, intermediateType), resultType);
469
470 // The join might not exist if the cast sequence would fail at runtime.
471 if (!firstJoin)
472 return failure();
473
 474 // The newJoin always exists if the above join exists; it might just contain
475 // less information. If so, we cannot drop the intermediate cast, as doing
476 // so would remove runtime checks.
477 auto newJoin = joinShapes(sourceType, resultType);
478 if (firstJoin != newJoin)
479 return failure();
480
481 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
482 tensorCastOperand.getOperand());
483 return success();
484 }
485};
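// For illustration (editorial): %1 = tensor.cast %0 : tensor<4x?xf32> to
// tensor<?x?xf32> followed by %2 = tensor.cast %1 : tensor<?x?xf32> to
// tensor<?x8xf32> collapses to %2 = tensor.cast %0 : tensor<4x?xf32> to
// tensor<?x8xf32>, since dropping the middle cast loses no runtime check.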
486
487/// Fold tensor.cast into its tensor.extract_slice producer.
488/// Example:
489/// ```
490/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
491/// tensor<128x512xf32> to tensor<?x512xf32>
492/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
493/// ```
494/// ->
495/// ```
496/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
497/// tensor<128x512xf32> to tensor<16x512xf32>
498/// ```
499struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
500 using OpRewritePattern<CastOp>::OpRewritePattern;
501
502 LogicalResult matchAndRewrite(CastOp tensorCast,
503 PatternRewriter &rewriter) const final {
504 auto extractOperand =
505 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
506
507 // Cannot fold cast to unranked tensor.
508 auto rankedResultType =
509 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
510 if (!rankedResultType)
511 return failure();
512
513 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
514 rankedResultType.getShape() ==
515 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
516 .getShape())
517 return failure();
518
519 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
520 auto dimMask = computeRankReductionMask(
521 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
522 size_t dimIndex = 0;
523 for (size_t i = 0, e = sizes.size(); i < e; i++) {
524 if (dimMask && dimMask->count(i))
525 continue;
526 int64_t dim = rankedResultType.getShape()[dimIndex++];
527 if (ShapedType::isDynamic(dim))
528 continue;
529 sizes[i] = rewriter.getIndexAttr(dim);
530 }
531
532 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
533 tensorCast, rankedResultType, extractOperand.getSource(),
534 extractOperand.getMixedOffsets(), sizes,
535 extractOperand.getMixedStrides());
536 return success();
537 }
538};
539
540} // namespace
541
542void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
543 MLIRContext *context) {
544 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
545}
546
547//===----------------------------------------------------------------------===//
548// ConcatOp
549//===----------------------------------------------------------------------===//
550
551RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
552 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
553 auto tensorTypes =
554 llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
555 return llvm::cast<RankedTensorType>(type);
556 }));
557 int64_t concatRank = tensorTypes[0].getRank();
558
559 // The concatenation dim must be in the range [0, rank).
560 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
561
562 SmallVector<int64_t> sizes(concatRank);
563 for (int64_t i = 0, e = concatRank; i < e; ++i) {
564 if (i == dim)
565 continue;
566 SaturatedInteger size;
567 for (auto tensorType : tensorTypes)
568 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
569 sizes[i] = size.asInteger();
570 }
571 auto concatSize = SaturatedInteger::wrap(0);
572 for (auto tensorType : tensorTypes)
573 concatSize =
574 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
575 sizes[dim] = concatSize.asInteger();
576 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
577}
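// Illustrative examples: concatenating tensor<3x?xf32> and tensor<5x4xf32>
// along dim 0 infers tensor<8x4xf32> (the dynamic dim 1 is refined by the
// static 4); concatenating tensor<3x?xf32> and tensor<3x4xf32> along dim 1
// infers tensor<3x?xf32>, because a dynamic extent makes the summed size
// dynamic.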
578
579void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
580 ValueRange inputs) {
581 FailureOr<RankedTensorType> resultType =
582 inferResultType(dim, inputs.getTypes());
583 assert(succeeded(resultType) && "failed to infer concatenation result type");
584 build(builder, result, *resultType, dim, inputs);
585}
586
587LogicalResult ConcatOp::verify() {
588 if (getInputs().size() < 1)
589 return emitOpError("requires at least one input");
590
 591 SmallVector<RankedTensorType> inputTypes;
 592 for (auto input : getInputs())
593 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
594
595 RankedTensorType resultType = getResultType();
596 int64_t resultRank = getRank();
597 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
598 return type.getRank() != resultRank;
599 }))
600 return emitOpError("rank of concatenated inputs must match result rank");
601
602 Type resultElementType = resultType.getElementType();
603 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
604 return type.getElementType() != resultElementType;
605 }))
606 return emitOpError("inputs and result element type must match");
607
608 int64_t dim = getDim();
609 if (dim >= resultRank)
610 return emitOpError("concatenation dim must be less than the tensor rank");
611
612 SmallVector<int64_t> sizes(resultRank);
613 for (int64_t i = 0, e = resultRank; i < e; ++i) {
614 if (i == dim)
615 continue;
616 SaturatedInteger size;
617 for (auto tensorType : inputTypes) {
618 FailureOr<SaturatedInteger> maybeSize =
619 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
620 if (failed(maybeSize))
621 return emitOpError("static concatenation size mismatch along ")
622 << "non-concatenated dimension " << i;
623 size = *maybeSize;
624 }
625 sizes[i] = size.asInteger();
626 }
627 auto concatSize = SaturatedInteger::wrap(0);
628 for (auto tensorType : inputTypes)
629 concatSize =
630 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
631 sizes[dim] = concatSize.asInteger();
632 auto inferredResultType =
633 RankedTensorType::get(sizes, inputTypes[0].getElementType());
634
635 for (auto [inferredSize, actualSize] :
636 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
637 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
638 ShapedType::isDynamic(actualSize);
639 if (!hasDynamic && inferredSize != actualSize)
640 return emitOpError("result type ")
 641 << resultType << " does not match inferred shape "
642 << inferredResultType << " static sizes";
643 }
644
645 return success();
646}
647
648FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
649 size_t numInputs = getInputs().size();
650 uint64_t concatDim = getDim();
651
 652 SmallVector<SmallVector<OpFoldResult>> inputShapes;
 653 inputShapes.reserve(numInputs);
654 SmallVector<OpFoldResult> concatOffsets;
655 concatOffsets.reserve(numInputs);
656 SmallVector<OpFoldResult> outputShape;
657
658 AffineExpr addExpr =
659 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
660 OpFoldResult zero = builder.getIndexAttr(0);
661 Location loc = getLoc();
662 for (auto [index, input] : llvm::enumerate(getInputs())) {
663 SmallVector<OpFoldResult> inputShape =
664 tensor::getMixedSizes(builder, input.getLoc(), input);
665 if (index == 0) {
666 outputShape = inputShape;
667 concatOffsets.push_back(zero);
668 } else {
669 concatOffsets.push_back(outputShape[concatDim]);
670 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
671 builder, loc, addExpr,
672 {outputShape[concatDim], inputShape[concatDim]});
673 }
674 inputShapes.emplace_back(std::move(inputShape));
675 }
676
 677 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
 678 getType().getElementType());
679
680 int64_t rank = getType().getRank();
681 OpFoldResult one = builder.getIndexAttr(1);
682 SmallVector<OpFoldResult> strides(rank, one);
683 SmallVector<OpFoldResult> offsets(rank, zero);
684 for (auto [index, input] : llvm::enumerate(getInputs())) {
685 offsets[concatDim] = concatOffsets[index];
686 auto insertSlice = tensor::InsertSliceOp::create(
687 builder, loc, input, replacement, offsets, inputShapes[index], strides);
688 replacement = insertSlice.getResult();
689 }
690 if (replacement.getType() != getType()) {
691 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
692 }
 693 return SmallVector<Value>{replacement};
 694}
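// For illustration (editorial sketch): tensor.concat dim(1) %a, %b :
// (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32> decomposes into a
// tensor.empty of the summed shape followed by two tensor.insert_slice ops at
// offsets [0, 0] and [0, 3], plus a final tensor.cast if the refined type
// differs from the declared result type.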
695
696LogicalResult
697ConcatOp::reifyResultShapes(OpBuilder &builder,
698 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
699 ValueRange inputs = getInputs();
700 int64_t dim = getDim();
701 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
702
703 Value init = inputs[0];
704 int64_t rank = getType().getRank();
705
706 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
707
708 // Pre-populate the result sizes with as much static information as possible
709 // from the given result type, as well as the inferred result type, otherwise
710 // use the dim sizes from the first input.
711 for (int64_t i = 0; i < rank; ++i) {
712 if (i == dim)
713 continue;
714 if (!getType().isDynamicDim(i)) {
715 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
716 } else if (!inferredResultType.isDynamicDim(i)) {
717 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
718 builder, getLoc(),
719 builder.getIndexAttr(inferredResultType.getDimSize(i)));
720 } else {
721 reifiedReturnShapes[0][i] =
722 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
723 }
724 }
725
726 if (getType().isDynamicDim(dim)) {
727 // Take the sum of the input sizes along the concatenated dim.
728 AffineExpr sum = builder.getAffineDimExpr(0);
 729 SmallVector<OpFoldResult> sizes = {
 730 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
731 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
732 sum = sum + builder.getAffineDimExpr(idx + 1);
733 sizes.push_back(
734 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
735 }
736 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
737 builder, getLoc(),
738 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
739 } else {
740 // If the result shape is static along the concatenated dim, use the static
741 // shape.
742 reifiedReturnShapes[0][dim] =
743 builder.getIndexAttr(getType().getDimSize(dim));
744 }
745 return success();
746}
747
748void ConcatOp::getAsmResultNames(
749 function_ref<void(Value, StringRef)> setNameFn) {
750 setNameFn(getResult(), "concat");
751}
752
753OpFoldResult ConcatOp::fold(FoldAdaptor) {
754 ValueRange inputs = getInputs();
755 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
756 return inputs[0];
757 return {};
758}
759
760namespace {
761/// Fold a concat op with a single input to a cast.
762struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
763 using OpRewritePattern<ConcatOp>::OpRewritePattern;
764
765 LogicalResult matchAndRewrite(ConcatOp concatOp,
766 PatternRewriter &rewriter) const override {
767 if (concatOp.getInputs().size() != 1)
768 return failure();
769 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
770 concatOp.getInputs()[0]);
771 return success();
772 }
773};
774
775/// Propagate static shapes into the operands of a `tensor.concat`.
776///
777/// `tensor.concat` requires every operand to match on all dimensions except the
778/// concatenation dimension. If one operand is already static in those
779/// dimensions, the other operands may safely be refined to that same static
780/// shape.
781///
782/// Example:
783///
784/// ```mlir
785/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
786/// tensor<?x12xi32>
787/// ```
788/// ->
789/// ```mlir
790/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
791/// %2 = tensor.concat dim(0) %0, %cast :
792/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
793/// ```
794struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
795 using OpRewritePattern<ConcatOp>::OpRewritePattern;
796
797 LogicalResult matchAndRewrite(ConcatOp concatOp,
798 PatternRewriter &rewriter) const override {
799 int64_t dim = concatOp.getDim();
800 RankedTensorType inferredResultType =
801 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
802
803 // Find operands for which a more static shape can be inferred.
804 LogicalResult matched = failure();
805 // Inferred operand shapes are identical in every dimension except the
806 // concatenation dimension.
807 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
808 for (auto [operandIdx, operandType] :
809 llvm::enumerate(concatOp->getOperandTypes())) {
810 // Compute inferred type for operand.
811 inferredOperandShape[dim] =
812 cast<RankedTensorType>(operandType).getDimSize(dim);
813 auto inferredOperandType = RankedTensorType::get(
814 inferredOperandShape, inferredResultType.getElementType());
815
816 // Check if inferred type is more static.
817 if (!preservesStaticInformation(inferredOperandType, operandType)) {
818 matched = success();
819
820 // Use refined operand type and create cast from original operand.
821 auto castOp =
822 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
823 concatOp.getOperand(operandIdx));
824 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
825 concatOp->setOperand(operandIdx, castOp->getResult(0));
826 });
827 }
828 }
829
830 return matched;
831 }
832};
833
834/// Ensure `tensor.concat`'s result type is at least as static as can be inferred
835/// from its operand types.
836///
837/// Example:
838/// ```mlir
839/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
840/// tensor<?x?xi32>
841/// ```
842/// ->
843/// ```mlir
844/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
845/// -> tensor<?x12xi32>
846/// %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
847/// ```
848struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
849 using OpRewritePattern<ConcatOp>::OpRewritePattern;
850
851 LogicalResult matchAndRewrite(ConcatOp concatOp,
852 PatternRewriter &rewriter) const override {
853 int64_t dim = concatOp.getDim();
854 RankedTensorType inferredResultType =
855 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
856
857 // The result type should be at least as static as inferred result type.
858 if (preservesStaticInformation(inferredResultType,
859 concatOp.getResultType())) {
860 return failure();
861 }
862
863 auto newConcatOp =
864 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
865 concatOp->getOperands());
866 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
867 newConcatOp);
868
869 return success();
870 }
871};
872} // namespace
873
874void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
875 MLIRContext *context) {
876 results
877 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
878 context);
879}
880
881//===----------------------------------------------------------------------===//
882// DimOp
883//===----------------------------------------------------------------------===//
884
885void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
886 setNameFn(getResult(), "dim");
887}
888
889void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
890 int64_t index) {
891 auto loc = result.location;
892 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
893 build(builder, result, source, indexValue);
894}
895
896std::optional<int64_t> DimOp::getConstantIndex() {
 897 return getConstantIntValue(getIndex());
 898}
899
900Speculation::Speculatability DimOp::getSpeculatability() {
901 auto constantIndex = getConstantIndex();
 902 if (!constantIndex)
 903 return Speculation::NotSpeculatable;
 904
 905 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
 906 if (!rankedSourceType)
 907 return Speculation::NotSpeculatable;
 908
 909 if (rankedSourceType.getRank() <= constantIndex)
 910 return Speculation::NotSpeculatable;
 911
 912 return Speculation::Speculatable;
 913}
914
915void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
916 SetIntLatticeFn setResultRange) {
917 setResultRange(getResult(),
918 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
919}
920
921OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
922 // All forms of folding require a known index.
923 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
924 if (!index)
925 return {};
926
927 // Folding for unranked types (UnrankedTensorType) is not supported.
928 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
929 if (!tensorType)
930 return {};
931
932 // Out of bound indices produce undefined behavior but are still valid IR.
933 // Don't choke on them.
934 int64_t indexVal = index.getInt();
935 if (indexVal < 0 || indexVal >= tensorType.getRank())
936 return {};
937
938 // Fold if the shape extent along the given index is known.
939 if (!tensorType.isDynamicDim(index.getInt())) {
940 Builder builder(getContext());
941 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
942 }
943
944 Operation *definingOp = getSource().getDefiningOp();
945
946 // Fold dim to the operand of tensor.generate.
947 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
948 auto resultType =
949 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
950 // The case where the type encodes the size of the dimension is handled
951 // above.
952 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
953
954 // Find the operand of the fromElements that corresponds to this index.
955 auto dynExtents = fromElements.getDynamicExtents().begin();
956 for (auto dim : resultType.getShape().take_front(index.getInt()))
957 if (ShapedType::isDynamic(dim))
958 dynExtents++;
959
960 return Value{*dynExtents};
961 }
962
963 // The size at the given index is now known to be a dynamic size.
964 unsigned unsignedIndex = index.getValue().getZExtValue();
965
966 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
967 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
968 // `resolve-shaped-type-result-dims` pass.
969 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
970 sliceOp.isDynamicSize(unsignedIndex)) {
971 return {sliceOp.getDynamicSize(unsignedIndex)};
972 }
973 }
974
975 // dim(cast) -> dim
976 if (succeeded(foldTensorCast(*this)))
977 return getResult();
978
979 return {};
980}
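// Illustrative folds: tensor.dim %t, %c0 with %t : tensor<4x?xf32> folds to
// the constant 4, while tensor.dim %t, %c1 only folds if the dynamic extent
// can be traced to a producer such as tensor.generate or a non-rank-reducing
// tensor.extract_slice.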
981
982namespace {
983/// Fold dim of a cast into the dim of the source of the tensor cast.
984struct DimOfCastOp : public OpRewritePattern<DimOp> {
985 using OpRewritePattern<DimOp>::OpRewritePattern;
986
987 LogicalResult matchAndRewrite(DimOp dimOp,
988 PatternRewriter &rewriter) const override {
989 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
990 if (!castOp)
991 return failure();
992 Value newSource = castOp.getOperand();
993 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
994 return success();
995 }
996};
997
998/// Fold dim of a destination passing style op into the dim of the corresponding
999/// init.
1000struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1001 using OpRewritePattern<DimOp>::OpRewritePattern;
1002
1003 LogicalResult matchAndRewrite(DimOp dimOp,
1004 PatternRewriter &rewriter) const override {
1005 auto source = dimOp.getSource();
1006 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1007 if (!destOp)
1008 return failure();
1009
1010 auto resultIndex = cast<OpResult>(source).getResultNumber();
1011 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1012
1013 rewriter.modifyOpInPlace(
1014 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1015 return success();
1016 }
1017};
1018
1019/// Fold the dim of a tensor reshape operation into an extract from the reshape's
1020/// shape operand.
1021struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1022 using OpRewritePattern<DimOp>::OpRewritePattern;
1023
1024 LogicalResult matchAndRewrite(DimOp dim,
1025 PatternRewriter &rewriter) const override {
1026 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1027
1028 if (!reshape)
1029 return failure();
1030
 1031 // Since tensors are immutable, we don't need to worry about where to place
 1032 // the extract call.
1033 rewriter.setInsertionPointAfter(dim);
1034 Location loc = dim.getLoc();
1035 Value extract =
1036 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1037 if (extract.getType() != dim.getType())
1038 extract =
1039 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1040 rewriter.replaceOp(dim, extract);
1041 return success();
1042 }
1043};
1044} // namespace
1045
1046void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1047 MLIRContext *context) {
1048 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// EmptyOp
1053//===----------------------------------------------------------------------===//
1054
1055void EmptyOp::build(OpBuilder &builder, OperationState &result,
1056 ArrayRef<int64_t> staticShape, Type elementType,
1057 Attribute encoding) {
1058 assert(none_of(staticShape, ShapedType::isDynamic) &&
1059 "expected only static sizes");
1060 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1061}
1062
1063void EmptyOp::build(OpBuilder &builder, OperationState &result,
1064 ArrayRef<int64_t> staticShape, Type elementType,
1065 ValueRange dynamicSizes, Attribute encoding) {
1066 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1067 build(builder, result, tensorType, dynamicSizes);
1068}
1069
1070void EmptyOp::build(OpBuilder &builder, OperationState &result,
1071 ArrayRef<OpFoldResult> sizes, Type elementType,
1072 Attribute encoding) {
1073 SmallVector<int64_t> staticShape;
1074 SmallVector<Value> dynamicSizes;
1075 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1076 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1077}
1078
1079LogicalResult EmptyOp::verify() {
1080 if (getType().getNumDynamicDims() != getDynamicSizes().size())
1081 return emitOpError("incorrect number of dynamic sizes, has ")
1082 << getDynamicSizes().size() << ", expected "
1083 << getType().getNumDynamicDims();
1084 return success();
1085}
1086
1087LogicalResult
1088EmptyOp::reifyResultShapes(OpBuilder &builder,
1089 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1090 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1091 unsigned ctr = 0;
1092 for (int64_t i = 0; i < getType().getRank(); ++i) {
1093 if (getType().isDynamicDim(i)) {
1094 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1095 } else {
1096 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1097 }
1098 }
1099 return success();
1100}
1101
1102Value EmptyOp::getDynamicSize(unsigned idx) {
1103 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1104 unsigned ctr = 0;
1105 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1106 if (getType().isDynamicDim(i))
1107 ++ctr;
1108 return getDynamicSizes()[ctr];
1109}
1110
1111SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1112 SmallVector<OpFoldResult> result;
1113 unsigned ctr = 0;
1114 OpBuilder b(getContext());
1115 for (int64_t i = 0; i < getType().getRank(); ++i) {
1116 if (getType().isDynamicDim(i)) {
1117 result.push_back(getDynamicSizes()[ctr++]);
1118 } else {
1119 result.push_back(b.getIndexAttr(getType().getShape()[i]));
1120 }
1121 }
1122 return result;
1123}
1124
1125namespace {
1126/// Change the type of the result of a `tensor.empty` by making the result
1127/// type statically sized along dimensions that in the original operation were
1128/// defined as dynamic, but the size was defined using a `constant` op. For
1129/// example
1130///
1131/// %c5 = arith.constant 5: index
1132/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1133///
1134/// to
1135///
1136/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1137struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1138 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1139
1140 LogicalResult matchAndRewrite(EmptyOp op,
1141 PatternRewriter &rewriter) const override {
1142 SmallVector<Value> foldedDynamicSizes;
1143 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1144 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1145
1146 // Stop here if no dynamic size was promoted to static.
1147 if (foldedTensorType == op.getType())
1148 return failure();
1149
1150 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1151 foldedDynamicSizes);
1152 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1153 return success();
1154 }
1155};
1156
1157struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1158 using OpRewritePattern<DimOp>::OpRewritePattern;
1159
1160 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1161 PatternRewriter &rewriter) const override {
1162 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1163 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1164 if (!emptyTensorOp || !maybeConstantIndex)
1165 return failure();
1166 auto emptyTensorType = emptyTensorOp.getType();
1167 if (*maybeConstantIndex < 0 ||
1168 *maybeConstantIndex >= emptyTensorType.getRank() ||
1169 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1170 return failure();
1171 rewriter.replaceOp(dimOp,
1172 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1173 return success();
1174 }
1175};
1176
1177/// Canonicalize
1178///
1179/// ```mlir
1180/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1181/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1182/// ```
1183///
1184/// into
1185///
1186/// ```mlir
1187/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1188/// ```
1189///
1190/// This assumes the input program is correct in terms of its shape. So it is
1191/// safe to assume that `%d0` is in fact 4.
1192struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1193 using OpRewritePattern<CastOp>::OpRewritePattern;
1194
1195 LogicalResult matchAndRewrite(CastOp castOp,
1196 PatternRewriter &rewriter) const override {
1197 if (!canFoldIntoProducerOp(castOp))
1198 return failure();
1199 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1200 if (!producer)
1201 return failure();
1202
1203 auto resultType =
1204 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1205 ArrayRef<int64_t> resultShape = resultType.getShape();
1206 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1207 SmallVector<OpFoldResult> newMixedSizes;
1208 newMixedSizes.reserve(currMixedSizes.size());
1209 assert(resultShape.size() == currMixedSizes.size() &&
1210 "mismatch in result shape and sizes of empty op");
1211 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1212 int64_t newDim = std::get<0>(it);
1213 OpFoldResult currDim = std::get<1>(it);
1214 // Case 1: The empty tensor dim is static. Check that the tensor cast
1215 // result dim matches.
1216 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1217 if (ShapedType::isDynamic(newDim) ||
1218 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1219 // Something is off, the cast result shape cannot be more dynamic
1220 // than the empty tensor result shape (enforced by
1221 // `canFoldIntoProducer`). Abort for now.
1222 return rewriter.notifyMatchFailure(
1223 producer, "mismatch in static value of shape of empty tensor "
1224 "result and cast result");
1225 }
1226 newMixedSizes.push_back(attr);
1227 continue;
1228 }
1229
1230 // Case 2 : The tensor cast shape is static, but empty tensor result
1231 // shape is dynamic.
1232 if (ShapedType::isStatic(newDim)) {
1233 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1234 continue;
1235 }
1236
1237 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1238 // shape is dynamic. Use the dynamic value from the empty tensor op.
1239 newMixedSizes.push_back(currDim);
1240 }
1241
1242 // TODO: Do not drop tensor encoding.
1243 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1244 resultType.getElementType());
1245 return success();
1246 }
1247};
1248
1249} // namespace
1250
1251void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1252 MLIRContext *context) {
1253 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1254 ReplaceEmptyTensorStaticShapeDims>(context);
1255}
1256
1257//===----------------------------------------------------------------------===//
1258// ExtractOp
1259//===----------------------------------------------------------------------===//
1260
1261namespace {
1262
1263/// Canonicalizes the pattern of the form
1264///
1265/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1266/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1267///
1268/// to
1269///
1270/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1271struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1272 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1273
1274 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1275 PatternRewriter &rewriter) const final {
1276 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1277 if (!tensorCast)
1278 return failure();
1279 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1280 return failure();
1281 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1282 extract, tensorCast.getSource(), extract.getIndices());
1283 return success();
1284 }
1285};
1286
1287/// Canonicalizes the pattern of the form
1288///
1289/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1290/// tensor<12xf64>
1291/// %extracted_element = tensor.extract %val[%c10] :
1292/// tensor<12xf64>
1293///
1294/// to
1295///
1296/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1297struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1298 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1299
1300 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1301 PatternRewriter &rewriter) const final {
1302 auto collapseOp =
1303 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1304 if (!collapseOp)
1305 return failure();
1306 if (!collapseOp.getSrcType().hasStaticShape())
1307 return failure();
1308
1309 auto sourceSizes = collapseOp.getSrcType().getShape();
1310
1311 SmallVector<Value> indices(extractOp.getIndices().begin(),
1312 extractOp.getIndices().end());
1313 SmallVector<Value> sourceIndices;
1314 for (auto [index, group] :
1315 llvm::zip(indices, collapseOp.getReassociationIndices())) {
1316 assert(!group.empty() && "association indices groups cannot be empty");
1317 auto groupSize = group.size();
1318
1319 if (groupSize == 1) {
1320 sourceIndices.push_back(index);
1321 continue;
1322 }
1323
1324 SmallVector<int64_t> basis =
1325 llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1326 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1327 rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1328 llvm::append_range(sourceIndices, delinearize.getResults());
1329 }
1330 if (collapseOp.getReassociationIndices().empty()) {
1331 auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1332 int64_t srcRank =
1333 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1334 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1335 rewriter, extractOp.getLoc(), zeroAffineMap,
1336 ArrayRef<OpFoldResult>{});
1337 for (int64_t i = 0; i < srcRank; i++) {
1338 sourceIndices.push_back(
1339 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1340 }
1341 }
1342
1343 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1344 extractOp, collapseOp.getSrc(), sourceIndices);
1345 return success();
1346 }
1347};
1348
1349} // namespace
1350
1351void ExtractOp::getAsmResultNames(
1352 function_ref<void(Value, StringRef)> setNameFn) {
1353 setNameFn(getResult(), "extracted");
1354}
1355
1356LogicalResult ExtractOp::verify() {
1357 // Verify the # indices match if we have a ranked type.
1358 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1359 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1360 return emitOpError("incorrect number of indices for extract_element");
1361 return success();
1362}
1363
1364/// If we have an ExtractOp consuming an InsertOp with the same
1365/// indices, we can return the InsertOp's scalar directly.
1366// TODO: This only checks the immediate producer; extend to go up the
1367// insert/extract chain if the slices are disjoint.
1368static Value foldExtractAfterInsert(ExtractOp extractOp) {
1369 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1370
1371 auto isSame = [](Value a, Value b) {
 1372 return getAsOpFoldResult(a) == getAsOpFoldResult(b);
 1373 };
1374 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1375 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1376 return insertOp.getScalar();
1377
1378 return {};
1379}
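// For illustration: %t2 = tensor.insert %v into %t[%i, %j] followed by
// tensor.extract %t2[%i, %j] folds to %v, provided the index operands compare
// equal as OpFoldResults.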
1380
1381OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1382 if (Attribute tensor = adaptor.getTensor()) {
1383 // If this is a splat elements attribute, simply return the value.
1384 // All of the elements of a splat attribute are the same.
1385 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1386 return splatTensor.getSplatValue<Attribute>();
1387
1388 // If this is a dense resource elements attribute, return.
1389 if (isa<DenseResourceElementsAttr>(tensor))
1390 return {};
1391 }
1392
1393 // Collect the constant indices into the tensor.
1394 SmallVector<uint64_t, 8> indices;
1395 for (Attribute indice : adaptor.getIndices()) {
1396 if (!indice || !llvm::isa<IntegerAttr>(indice))
1397 return {};
1398 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1399 }
1400
1401 // Fold extract(from_elements(...)).
1402 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1403 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1404 auto rank = tensorType.getRank();
1405 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1406 "rank mismatch");
1407 int flatIndex = 0;
1408 int stride = 1;
1409 for (int i = rank - 1; i >= 0; --i) {
1410 flatIndex += indices[i] * stride;
1411 stride *= tensorType.getDimSize(i);
1412 }
1413 // Prevent out of bounds accesses. This can happen in invalid code that
1414 // will never execute.
1415 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1416 flatIndex < 0)
1417 return {};
1418 return fromElementsOp.getElements()[flatIndex];
1419 }
1420
1421 // If this is an elements attribute, query the value at the given indices.
1422 if (Attribute tensor = adaptor.getTensor()) {
1423 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1424 if (elementsAttr && elementsAttr.isValidIndex(indices))
1425 return elementsAttr.getValues<Attribute>()[indices];
1426 }
1427
1428 if (Value result = foldExtractAfterInsert(*this))
1429 return result;
1430
1431 return {};
1432}
1433
1434void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1435 MLIRContext *context) {
1436 results.add<ExtractFromTensorCast>(context);
1437}
1438
1439void mlir::tensor::populateFoldCollapseExtractPatterns(
 1440 RewritePatternSet &patterns) {
 1441 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1442}
1443
1444//===----------------------------------------------------------------------===//
1445// FromElementsOp
1446//===----------------------------------------------------------------------===//
1447
1448void FromElementsOp::getAsmResultNames(
1449 function_ref<void(Value, StringRef)> setNameFn) {
1450 setNameFn(getResult(), "from_elements");
1451}
1452
1453void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1454 ValueRange elements) {
1455 assert(!elements.empty() && "expected at least one element");
1456 Type resultType = RankedTensorType::get(
1457 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1458 build(builder, result, resultType, elements);
1459}
1460
1461OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1462 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1463 return DenseElementsAttr::get(getType(), adaptor.getElements());
1464 return {};
1465}
1466
1467namespace {
1468
1469// Pushes the index_casts that occur before extractions to after the extract.
1470// This minimizes type conversion in some cases and enables the extract
1471// canonicalizer. This changes:
1472//
1473// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1474// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1475//
1476// to the following:
1477//
1478// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1479// %cast = arith.index_cast %extract : i32 to index
1480//
1481// so the cast now operates on the extracted scalar rather than the tensor.
1482//
1483// Consider expanding this to a template and handle all tensor cast
1484// operations.
1485struct ExtractElementFromIndexCast
1486 : public OpRewritePattern<tensor::ExtractOp> {
1487 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1488
1489 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1490 PatternRewriter &rewriter) const final {
1491 Location loc = extract.getLoc();
1492 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1493 if (!indexCast)
1494 return failure();
1495
1496 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1497
1498 auto newExtract = tensor::ExtractOp::create(
1499 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1500
1501 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1502 newExtract);
1503
1504 return success();
1505 }
1506};
1507
1508} // namespace
1509
1510void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1511 MLIRContext *context) {
1512 results.add<ExtractElementFromIndexCast>(context);
1513}
1514
1515//===----------------------------------------------------------------------===//
1516// GatherOp
1517//===----------------------------------------------------------------------===//
1518
1519void GatherOp::getAsmResultNames(
1520 function_ref<void(Value, StringRef)> setNameFn) {
1521 setNameFn(getResult(), "gather");
1522}
1523
1524/// Return the inferred result type for a gatherOp where:
1525/// - sourceType is the type of the source tensor gathered from
1526/// - indicesType is the type of the indices used to gather
1527/// - gatherDims are the dims along which the gather occurs.
1528/// Return a full rank or ranked-reduced variant of the type depending on
1529/// the value of rankReduced.
1530///
1531/// The leading dimensions of the index tensor give the result tensor its
1532/// leading dimensions.
1533/// The trailing dimensions of the result tensor are obtained from the source
1534/// tensor by setting the dimensions specified in gather_dims to `1` (if
1535/// rankedReduced is false), or skipping them (otherwise).
1536RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1537 RankedTensorType indicesType,
1538 ArrayRef<int64_t> gatherDims,
1539 bool rankReduced) {
1540 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1541 resultShape.reserve(resultShape.size() + sourceType.getRank());
1542 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1543 if (llvm::binary_search(gatherDims, idx)) {
1544 if (!rankReduced)
1545 resultShape.push_back(1);
1546 continue;
1547 }
1548 resultShape.push_back(sourceType.getDimSize(idx));
1549 }
1550 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1551}
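// Illustrative example: source tensor<3x4x5xf32>, indices tensor<6x7x1xindex>,
// gather_dims = [1] infers tensor<6x7x3x1x5xf32>, or tensor<6x7x3x5xf32> in
// the rank-reduced variant (the gathered dim is kept as 1 or skipped).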
1552
1553static LogicalResult
1554verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
 1555 ArrayRef<int64_t> indices, int64_t rank,
 1556 StringRef gatherOrScatter, StringRef sourceOrDest) {
1557 if (dims.empty())
1558 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1559
1560 int64_t numGatherDims = dims.size();
1561 if (numGatherDims > rank)
1562 return op->emitOpError(gatherOrScatter)
1563 << "_dims overflow " << sourceOrDest << " rank";
1564 if (indices.empty() || indices.back() != numGatherDims)
1565 return op->emitOpError(gatherOrScatter)
1566 << "_dims length must match the size of last dimension of indices";
1567 for (int64_t val : dims) {
1568 if (val < 0)
1569 return op->emitOpError(gatherOrScatter)
1570 << "_dims value must be non-negative";
1571 if (val >= rank)
1572 return op->emitOpError(gatherOrScatter)
1573 << "_dims value must be smaller than " << sourceOrDest << " rank";
1574 }
1575 for (int64_t i = 1; i < numGatherDims; ++i) {
1576 if (dims[i - 1] >= dims[i])
1577 return op->emitOpError(gatherOrScatter)
1578 << "_dims values must be strictly increasing";
1579 }
1580 return success();
1581}
1582
1583LogicalResult GatherOp::verify() {
1584 int64_t sourceRank = getSourceType().getRank();
1585 ArrayRef<int64_t> gatherDims = getGatherDims();
1586 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1587 getIndicesType().getShape(), sourceRank,
1588 "gather", "source")))
1589 return failure();
1590
1591 RankedTensorType expectedResultType = GatherOp::inferResultType(
1592 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1593 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1594 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1595 if (getResultType() != expectedResultType &&
1596 getResultType() != expectedRankReducedResultType) {
1597 return emitOpError("result type "
1598 "mismatch: "
1599 "expected ")
1600 << expectedResultType << " or its rank-reduced variant "
1601 << expectedRankReducedResultType << " (got: " << getResultType()
1602 << ")";
1603 }
1604
1605 return success();
1606}
1607
1608OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1609 if (OpFoldResult reshapedSource = reshapeConstantSource(
1610 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1611 getResult().getType()))
1612 return reshapedSource;
1613 return {};
1614}
1615
1616//===----------------------------------------------------------------------===//
1617// InsertOp
1618//===----------------------------------------------------------------------===//
1619
1620void InsertOp::getAsmResultNames(
1621 function_ref<void(Value, StringRef)> setNameFn) {
1622 setNameFn(getResult(), "inserted");
1623}
1624
1625LogicalResult InsertOp::verify() {
1626 // Verify the # indices match if we have a ranked type.
1627 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1628 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1629 return emitOpError("incorrect number of indices");
1630 return success();
1631}
1632
1633OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1634 Attribute scalar = adaptor.getScalar();
1635 Attribute dest = adaptor.getDest();
1636 if (scalar && dest)
1637 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1638 if (scalar == splatDest.getSplatValue<Attribute>())
1639 return dest;
1640 return {};
1641}
1642
1643//===----------------------------------------------------------------------===//
1644// GenerateOp
1645//===----------------------------------------------------------------------===//
1646
1647void GenerateOp::getAsmResultNames(
1648 function_ref<void(Value, StringRef)> setNameFn) {
1649 setNameFn(getResult(), "generated");
1650}
1651
1652LogicalResult GenerateOp::reifyResultShapes(
1653 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1654 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1655 int idx = 0;
1656 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1657 if (getType().isDynamicDim(dim)) {
1658 reifiedReturnShapes[0][dim] = getOperand(idx++);
1659 } else {
1660 reifiedReturnShapes[0][dim] =
1661 builder.getIndexAttr(getType().getDimSize(dim));
1662 }
1663 }
1664 return success();
1665}
1666
1667LogicalResult GenerateOp::verify() {
1668 // Ensure that the tensor type has as many dynamic dimensions as are
1669 // specified by the operands.
1670 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1671 if (getNumOperands() != resultType.getNumDynamicDims())
1672 return emitError("must have as many index operands as dynamic extents "
1673 "in the result type");
1674 return success();
1675}
1676
1677LogicalResult GenerateOp::verifyRegions() {
1678 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1679 // Ensure that region arguments span the index space.
1680 if (!llvm::all_of(getBody().getArgumentTypes(),
1681 [](Type ty) { return ty.isIndex(); }))
1682 return emitError("all body arguments must be index");
1683 if (getBody().getNumArguments() != resultTy.getRank())
1684 return emitError("must have one body argument per input dimension");
1685
1686 // Ensure that the region yields an element of the right type.
1687 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1688
1689 if (yieldOp.getValue().getType() != resultTy.getElementType())
1690 return emitOpError(
1691 "body must be terminated with a `yield` operation of the tensor "
1692 "element type");
1693
1694 return success();
1695}
1696
1697void GenerateOp::build(
1698 OpBuilder &b, OperationState &result, Type resultTy,
1699 ValueRange dynamicExtents,
1700 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1701 build(b, result, resultTy, dynamicExtents);
1702
1703 // Build and populate body.
1704 OpBuilder::InsertionGuard guard(b);
1705 Region *bodyRegion = result.regions.front().get();
1706 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1707 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1708 SmallVector<Location, 2> argumentLocs(rank, result.location);
1709 Block *bodyBlock =
1710 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1711 bodyBuilder(b, result.location, bodyBlock->getArguments());
1712}
1713
1714namespace {
1715
1716/// Canonicalizes tensor.generate operations whose dynamic extent operands
1717/// are constants by folding those extents into the static result type. We
1718/// also insert a tensor.cast so that the resulting IR is still well-typed.
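///
/// For instance (illustrative):
///   %c5 = arith.constant 5 : index
///   %g = tensor.generate %c5 { ... } : tensor<?xindex>
/// is rewritten to
///   %g = tensor.generate { ... } : tensor<5xindex>
///   %c = tensor.cast %g : tensor<5xindex> to tensor<?xindex>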
1720struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1721 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1722
1723 LogicalResult matchAndRewrite(GenerateOp generateOp,
1724 PatternRewriter &rewriter) const final {
1725 SmallVector<Value> foldedDynamicSizes;
1726 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1727 generateOp.getType(), generateOp.getDynamicExtents(),
1728 foldedDynamicSizes);
1729
1730 // Stop here if no dynamic size was promoted to static.
1731 if (foldedTensorType == generateOp.getType())
1732 return failure();
1733
1734 auto loc = generateOp.getLoc();
1735 auto newOp =
1736 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1737 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1738 newOp.getBody().begin());
1739 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1740 generateOp.getType(), newOp);
1741 return success();
1742 }
1743};
1744
1745/// Canonicalizes the pattern of the form
1746///
1747/// %tensor = tensor.generate %x {
1748/// ^bb0(%arg0: index):
1749/// <computation>
1750/// yield %1 : index
1751/// } : tensor<?xindex>
1752/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1753///
1754/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1755/// tensor.generate operation has no side-effects.
1756struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1757 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1758
1759 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1760 PatternRewriter &rewriter) const final {
1761 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1762 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1763 return failure();
1764
1765 IRMapping mapping;
1766 Block *body = &tensorFromElements.getBody().front();
1767 mapping.map(body->getArguments(), extract.getIndices());
1768 for (auto &op : body->without_terminator())
1769 rewriter.clone(op, mapping);
1770
1771 auto yield = cast<YieldOp>(body->getTerminator());
1772
1773 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1774 return success();
1775 }
1776};
1777
1778} // namespace
1779
1780void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1781 MLIRContext *context) {
1782 // TODO: Move extract pattern to tensor::ExtractOp.
1783 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1784}
1785
1786//===----------------------------------------------------------------------===//
1787// RankOp
1788//===----------------------------------------------------------------------===//
1789
1790void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1791 setNameFn(getResult(), "rank");
1792}
1793
1794OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1795 // Constant fold rank when the rank of the operand is known.
1796 auto type = getOperand().getType();
1797 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1798 if (shapedType && shapedType.hasRank())
1799 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1800 return IntegerAttr();
1801}
1802
1803//===----------------------------------------------------------------------===//
1804// ReshapeOp
1805//===----------------------------------------------------------------------===//
1806
1807void ReshapeOp::getAsmResultNames(
1808 function_ref<void(Value, StringRef)> setNameFn) {
1809 setNameFn(getResult(), "reshape");
1810}
1811
1812static int64_t getNumElements(ShapedType type) {
1813 int64_t numElements = 1;
1814 for (auto dim : type.getShape())
1815 numElements *= dim;
1816 return numElements;
1817}
1818
1819LogicalResult ReshapeOp::verify() {
1820 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1821 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1822
1823 if (operandType.getElementType() != resultType.getElementType())
1824 return emitOpError("element types of source and destination tensor "
1825 "types should be the same");
1826
1827 int64_t shapeSize =
1828 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1829 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1830 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1831
1832 if (resultRankedType) {
1833 if (operandRankedType && resultRankedType.hasStaticShape() &&
1834 operandRankedType.hasStaticShape()) {
1835 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1836 return emitOpError("source and destination tensor should have the "
1837 "same number of elements");
1838 }
1839 if (ShapedType::isDynamic(shapeSize))
1840 return emitOpError("cannot use shape operand with dynamic length to "
1841 "reshape to statically-ranked tensor type");
1842 if (shapeSize != resultRankedType.getRank())
1843 return emitOpError(
1844 "length of shape operand differs from the result's tensor rank");
1845 }
1846 return success();
1847}
1848
1849OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1850 if (OpFoldResult reshapedSource = reshapeConstantSource(
1851 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1852 getResult().getType()))
1853 return reshapedSource;
1854
1855 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1856 // producer's input instead as the original tensor to reshape. This may
1857 // render the producer dead.
1858 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1859 getSourceMutable().assign(reshapeOpProducer.getSource());
1860 return getResult();
1861 }
1862
1863 auto source = getSource();
1864 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1865 auto resultTy = dyn_cast<RankedTensorType>(getType());
1866 if (!sourceTy || !resultTy || sourceTy != resultTy)
1867 return {};
1868
1869 // If the source and result are both 0D or 1D tensors and have the same type,
1870 // the reshape has no effect, even if the tensor is dynamically shaped.
1871 if (sourceTy.getRank() <= 1)
1872 return source;
1873
1874 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1875 auto elements = fromElements.getElements();
1876 bool dynamicNoop =
1877 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1878 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1879 auto element = elements[id];
1880
1881 if (auto cst = getConstantIntValue(element)) {
1882 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1883 continue;
1884 }
1885
1886 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1887 dynamicNoop &= dimOp.getSource() == source;
1888
1889 auto cst = getConstantIntValue(dimOp.getIndex());
1890 dynamicNoop &=
1891 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1892 continue;
1893 }
1894
1895 dynamicNoop = false;
1896 break;
1897 }
1898
1899 if (dynamicNoop)
1900 return source;
1901 }
1902
1903 return {};
1904}
1905
1906//===----------------------------------------------------------------------===//
1907// Reassociative reshape ops
1908//===----------------------------------------------------------------------===//
1909
1910void CollapseShapeOp::getAsmResultNames(
1911 function_ref<void(Value, StringRef)> setNameFn) {
1912 setNameFn(getResult(), "collapsed");
1913}
1914
1915void ExpandShapeOp::getAsmResultNames(
1916 function_ref<void(Value, StringRef)> setNameFn) {
1917 setNameFn(getResult(), "expanded");
1918}
1919
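/// Maps a result (expanded) dimension to the source dimension whose
/// reassociation group contains it; e.g. with reassociation [[0, 1], [2]],
/// result dims 0 and 1 map to source dim 0 and result dim 2 maps to dim 1.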
1920int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1921 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1922 "invalid resultDim");
1923 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1924 if (llvm::is_contained(it.value(), resultDim))
1925 return it.index();
1926 llvm_unreachable("could not find reassociation group");
1927}
1928
1929FailureOr<SmallVector<OpFoldResult>>
1930ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1931 RankedTensorType expandedType,
1932 ArrayRef<ReassociationIndices> reassociation,
1933 ArrayRef<OpFoldResult> inputShape) {
1934 std::optional<SmallVector<OpFoldResult>> outputShape =
1935 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1936 inputShape);
1937 if (!outputShape)
1938 return failure();
1939 return *outputShape;
1940}
1941
1942SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1943 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1944}
1945
1946void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1947 Type resultType, Value src,
1948 ArrayRef<ReassociationIndices> reassociation,
1949 ArrayRef<OpFoldResult> outputShape) {
1950 auto [staticOutputShape, dynamicOutputShape] =
1951 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1952 build(builder, result, cast<RankedTensorType>(resultType), src,
1953 getReassociationIndicesAttribute(builder, reassociation),
1954 dynamicOutputShape, staticOutputShape);
1955}
1956
1957void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1958 Type resultType, Value src,
1959 ArrayRef<ReassociationIndices> reassociation) {
1960 SmallVector<OpFoldResult> inputShape =
1961 getMixedSizes(builder, result.location, src);
1962 auto tensorResultTy = cast<RankedTensorType>(resultType);
1963 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1964 builder, result.location, tensorResultTy, reassociation, inputShape);
1965 SmallVector<OpFoldResult> outputShapeOrEmpty;
1966 if (succeeded(outputShape)) {
1967 outputShapeOrEmpty = *outputShape;
1968 }
1969 build(builder, result, tensorResultTy, src, reassociation,
1970 outputShapeOrEmpty);
1971}
1972
1973SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1974 return getSymbolLessAffineMaps(getReassociationExprs());
1975}
1976SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1977 return convertReassociationIndicesToExprs(getContext(),
1978 getReassociationIndices());
1979}
1980
1981SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1982 return getSymbolLessAffineMaps(getReassociationExprs());
1983}
1984SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1985 return convertReassociationIndicesToExprs(getContext(),
1986 getReassociationIndices());
1987}
1988
1989RankedTensorType CollapseShapeOp::inferCollapsedType(
1990 RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1991 return inferCollapsedType(
1992 type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1993 type.getContext(), reassociation)));
1994}
1995
1996/// Compute the RankedTensorType obtained by applying `reassociation` to
1997/// `type`.
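///
/// For instance (illustrative), collapsing tensor<2x3x4xf32> with
/// reassociation [[0, 1], [2]] yields tensor<6x4xf32>; a group containing a
/// dynamic dimension collapses to a dynamic dimension.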
1998RankedTensorType
1999CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2000 ArrayRef<AffineMap> reassociation) {
2001 auto shape = type.getShape();
2002 SmallVector<int64_t, 4> newShape;
2003 newShape.reserve(reassociation.size());
2004
2005 // Use the fact that reassociation is valid to simplify the logic: only use
2006 // each map's rank.
2007 assert(isReassociationValid(reassociation) && "invalid reassociation");
2008 unsigned currentDim = 0;
2009 for (AffineMap m : reassociation) {
2010 unsigned dim = m.getNumResults();
2011 auto band = shape.slice(currentDim, dim);
2012 int64_t size = 1;
2013 if (llvm::is_contained(band, ShapedType::kDynamic))
2014 size = ShapedType::kDynamic;
2015 else
2016 for (unsigned d = 0; d < dim; ++d)
2017 size *= shape[currentDim + d];
2018 newShape.push_back(size);
2019 currentDim += dim;
2020 }
2021
2022 return RankedTensorType::get(newShape, type.getElementType());
2023}
2024
2025void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2026 ArrayRef<ReassociationIndices> reassociation,
2027 ArrayRef<NamedAttribute> attrs) {
2028 auto resultType = inferCollapsedType(
2029 llvm::cast<RankedTensorType>(src.getType()),
2030 getSymbolLessAffineMaps(
2031 convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2032 result.addAttribute(getReassociationAttrStrName(),
2033 getReassociationIndicesAttribute(b, reassociation));
2034 build(b, result, resultType, src, attrs);
2035}
2036
2037template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2038 TensorReshapeOp, ExpandShapeOp>::value>
2039static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2040 RankedTensorType expandedType,
2041 RankedTensorType collapsedType) {
2042 if (failed(
2043 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2044 return failure();
2045
2046 auto maps = op.getReassociationMaps();
2047 RankedTensorType expectedType =
2048 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2049 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2050 return op.emitOpError("expected collapsed type to be ")
2051 << expectedType << ", but got " << collapsedType;
2052 return success();
2053}
2054
2055LogicalResult ExpandShapeOp::verify() {
2056 auto srcType = getSrcType();
2057 auto resultType = getResultType();
2058
2059 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2060 return emitOpError("expected number of static shape dims to be equal to "
2061 "the output rank (")
2062 << resultType.getRank() << ") but found "
2063 << getStaticOutputShape().size() << " inputs instead";
2064
2065 if ((int64_t)getOutputShape().size() !=
2066 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2067 return emitOpError("mismatch in dynamic dims in output_shape and "
2068 "static_output_shape: static_output_shape has ")
2069 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2070 << " dynamic dims while output_shape has " << getOutputShape().size()
2071 << " values";
2072
2073 return verifyTensorReshapeOp(*this, resultType, srcType);
2074}
2075
2076LogicalResult CollapseShapeOp::verify() {
2077 return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2078}
2079
2080namespace {
2081/// Reshape of a splat constant can be replaced with a constant of the result
2082/// type.
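///
/// For instance (illustrative), expanding a constant
/// dense<1.0> : tensor<4xf32> into tensor<2x2xf32> is replaced by
/// arith.constant dense<1.0> : tensor<2x2xf32>.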
2083template <typename TensorReshapeOp>
2084struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2085 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2086 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2087 PatternRewriter &rewriter) const override {
2088 DenseElementsAttr attr;
2089 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2090 return failure();
2091 if (!attr || !attr.isSplat())
2092 return failure();
2093 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2094 reshapeOp.getResultType(), attr.getRawData());
2095 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2096 return success();
2097 }
2098};
2099
2100// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2101template <typename TensorReshapeOp>
2102class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2103public:
2104 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2105
2106 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2107 PatternRewriter &rewriter) const override {
2108 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2109 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2110 return failure();
2111
2112 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2113 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2114 return success();
2115 }
2116};
2117
2118/// Reshape of a FromElements can be replaced with a FromElements of the
2119/// result type.
2120template <typename TensorReshapeOp>
2121struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2122 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2123 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2124 PatternRewriter &rewriter) const override {
2125 auto fromElements =
2126 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2127 if (!fromElements)
2128 return failure();
2129
2130 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2131
2132 if (!shapedTy.hasStaticShape())
2133 return failure();
2134
2135 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2136 fromElements.getElements());
2137 return success();
2138 }
2139};
2140
2141// Fold CastOp into CollapseShapeOp when adding static information.
2142struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2143 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2144
2145 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2146 PatternRewriter &rewriter) const override {
2147 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2148 if (!tensor::canFoldIntoConsumerOp(castOp))
2149 return failure();
2150
2151 RankedTensorType srcType =
2152 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2153 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2154 srcType, collapseShapeOp.getReassociationMaps());
2155
2156 if (newResultType == collapseShapeOp.getResultType()) {
2157 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2158 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2159 });
2160 } else {
2161 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2162 newResultType, castOp.getSource(),
2163 collapseShapeOp.getReassociation());
2164 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2165 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2166 }
2167 return success();
2168 }
2169};
2170
2171/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2172/// matching constant output_shape operands of the expand. This makes the
2173/// `tensor.expand_shape` more static and creates a consumer cast that can be
2174/// propagated further.
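///
/// For instance (illustrative), with %c5 = arith.constant 5 : index:
///   %c = tensor.cast %s : tensor<10xf32> to tensor<?xf32>
///   %e = tensor.expand_shape %c [[0, 1]] output_shape [2, %c5]
///       : tensor<?xf32> into tensor<2x?xf32>
/// becomes, once the remaining casts fold,
///   %e = tensor.expand_shape %s [[0, 1]] output_shape [2, 5]
///       : tensor<10xf32> into tensor<2x5xf32>
///   %r = tensor.cast %e : tensor<2x5xf32> to tensor<2x?xf32>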
2175struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2176 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2177
2178 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2179 PatternRewriter &rewriter) const override {
2180 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2181 if (!canFoldIntoConsumerOp(castOp))
2182 return failure();
2183
2184 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2185 SmallVector<ReassociationIndices, 4> reassoc =
2186 expandOp.getReassociationIndices();
2187
2188 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2189 SmallVector<Value> dynamicOutputShape;
2190 auto outputIt = expandOp.getOutputShape().begin();
2191
2192 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2193 for (uint64_t outDim : innerReassoc) {
2194 if (ShapedType::isStatic(newOutputShape[outDim]))
2195 continue;
2196
2197 // If the cast's src type is dynamic, don't infer any of the
2198 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2199 // least one of the expanded dimensions to be dynamic if the input is
2200 // dynamic.
2201 Value val = *outputIt;
2202 ++outputIt;
2203 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2204 dynamicOutputShape.push_back(val);
2205 continue;
2206 }
2207
2208 APInt cst;
2209 if (matchPattern(val, m_ConstantInt(&cst))) {
2210 newOutputShape[outDim] = cst.getSExtValue();
2211 } else {
2212 dynamicOutputShape.push_back(val);
2213 }
2214 }
2215 }
2216
2217 // Couldn't match any values, nothing to change
2218 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2219 return failure();
2220
2221 // Calculate the input shape from the output
2222 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), int64_t(1));
2223 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2224 for (auto outDim : reassoc[inDim]) {
2225 auto ofr = newOutputShape[outDim];
2226 if (ShapedType::isDynamic(ofr)) {
2227 newInputShape[inDim] = ShapedType::kDynamic;
2228 break;
2229 }
2230 newInputShape[inDim] *= ofr;
2231 }
2232 }
2233
2234 SmallVector<OpFoldResult> outputOfr =
2235 getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2236 auto inputType = RankedTensorType::get(
2237 newInputShape, expandOp.getSrcType().getElementType());
2238 auto outputType = RankedTensorType::get(
2239 newOutputShape, expandOp.getSrcType().getElementType());
2240 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2241 expandOp.getSrc());
2242 auto newExpand = ExpandShapeOp::create(
2243 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2244 expandOp.getReassociationIndices(), outputOfr);
2245 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2246 newExpand.getResult());
2247 return success();
2248 }
2249};
2250} // namespace
2251
2252void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2253 MLIRContext *context) {
2254 results.add<
2255 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2256 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2257 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2258 FoldReshapeWithSplat<ExpandShapeOp>,
2259 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2260}
2261
2262void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2263 MLIRContext *context) {
2264 results.add<
2265 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2266 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2267 tensor::DimOp, RankedTensorType>,
2268 FoldReshapeWithConstant<CollapseShapeOp>,
2269 FoldReshapeWithSplat<CollapseShapeOp>,
2270 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2271 context);
2272}
2273
2274OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2275 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2276 adaptor.getOperands());
2277}
2278
2279OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2280 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2281 adaptor.getOperands());
2282}
2283
2284//===----------------------------------------------------------------------===//
2285// ExtractSliceOp
2286//===----------------------------------------------------------------------===//
2287
2288void ExtractSliceOp::getAsmResultNames(
2289 function_ref<void(Value, StringRef)> setNameFn) {
2290 setNameFn(getResult(), "extracted_slice");
2291}
2292
2293/// An extract_slice result type can be inferred, when it is not
2294/// rank-reduced, from the source type and the static representation of
2295/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2296RankedTensorType ExtractSliceOp::inferResultType(
2297 RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2298 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2299 // An extract_slice op may specify only a leading subset of offset/sizes/
2300 // strides, in which case we complete with offset=0, sizes from the source
2301 // tensor type and strides=1.
2302 assert(static_cast<int64_t>(staticSizes.size()) ==
2303 sourceTensorType.getRank() &&
2304 "unexpected staticSizes not equal to rank of source");
2305 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2306 sourceTensorType.getEncoding());
2307}
2308
2309// TODO: This uses neither offsets nor strides!
2310RankedTensorType ExtractSliceOp::inferResultType(
2311 RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2312 ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2313 SmallVector<int64_t> staticSizes;
2314 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2315 assert(static_cast<int64_t>(staticSizes.size()) ==
2316 sourceTensorType.getRank() &&
2317 "unexpected staticSizes not equal to rank of source");
2318 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2319 sourceTensorType.getEncoding());
2320}
2321
2322/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2323/// number of sizes), drop as many size 1 as needed to produce an inferred
2324/// type with the desired rank.
2325///
2326/// Note that there may be multiple ways to compute this rank-reduced type:
2327/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2328///
2329/// To disambiguate, this function always drops the first occurrences of size
/// 1 (so 1x6x1 with a desired rank of 2 becomes 6x1).
2330RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2331 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2332 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2333 ArrayRef<int64_t> strides) {
2334 // Type inferred in the absence of rank-reducing behavior.
2335 auto inferredType = llvm::cast<RankedTensorType>(
2336 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2337 int rankDiff = inferredType.getRank() - desiredResultRank;
2338 if (rankDiff > 0) {
2339 auto shape = inferredType.getShape();
2340 llvm::SmallBitVector dimsToProject =
2341 getPositionsOfShapeOne(rankDiff, shape);
2342 SmallVector<int64_t> projectedShape;
2343 // Best effort rank-reducing: drop 1s in order.
2344 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2345 if (!dimsToProject.test(pos))
2346 projectedShape.push_back(shape[pos]);
2347 inferredType =
2348 RankedTensorType::get(projectedShape, inferredType.getElementType());
2349 }
2350 return inferredType;
2351}
2352
2353RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2354 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2355 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2356 ArrayRef<OpFoldResult> strides) {
2357 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2358 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2359 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2360 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2361 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2362 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2363 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2364 staticStrides);
2365}
2366
2367/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2368/// result type. If the type passed is nullptr, it is inferred.
2369void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2370 RankedTensorType resultType, Value source,
2371 ArrayRef<OpFoldResult> offsets,
2372 ArrayRef<OpFoldResult> sizes,
2373 ArrayRef<OpFoldResult> strides,
2374 ArrayRef<NamedAttribute> attrs) {
2375 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2376 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2377 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2378 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2379 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2380 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2381 // Structuring implementation this way avoids duplication between builders.
2382 if (!resultType) {
2383 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2384 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2385 }
2386 result.addAttributes(attrs);
2387 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2388 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2389 b.getDenseI64ArrayAttr(staticSizes),
2390 b.getDenseI64ArrayAttr(staticStrides));
2391}
2392
2393/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2394/// result type.
2395void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2396 ArrayRef<OpFoldResult> offsets,
2397 ArrayRef<OpFoldResult> sizes,
2398 ArrayRef<OpFoldResult> strides,
2399 ArrayRef<NamedAttribute> attrs) {
2400 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2401}
2402
2403/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2404/// a Range vector.
2405void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2406 ArrayRef<Range> ranges,
2407 ArrayRef<NamedAttribute> attrs) {
2408 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2409 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2410}
2411
2412/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2413/// the type passed is nullptr, it is inferred.
2414void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2415 RankedTensorType resultType, Value source,
2416 ValueRange offsets, ValueRange sizes,
2417 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2418 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2419 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2420 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2421 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2422 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2423 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2424 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2425}
2426
2427/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2428void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2429 ValueRange offsets, ValueRange sizes,
2430 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2431 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2432}
2433
2433
2434static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2435 Operation *op,
2436 RankedTensorType expectedType) {
2437 switch (result) {
2438 case SliceVerificationResult::Success:
2439 return success();
2440 case SliceVerificationResult::RankTooLarge:
2441 return op->emitError("expected rank to be smaller or equal to ")
2442 << "the other rank. ";
2443 case SliceVerificationResult::SizeMismatch:
2444 return op->emitError("expected type to be ")
2445 << expectedType << " or a rank-reduced version. (size mismatch) ";
2446 case SliceVerificationResult::ElemTypeMismatch:
2447 return op->emitError("expected element type to be ")
2448 << expectedType.getElementType();
2449 default:
2450 llvm_unreachable("unexpected extract_slice op verification result");
2451 }
2452}
2453
2454/// Verifier for ExtractSliceOp.
2455LogicalResult ExtractSliceOp::verify() {
2456 RankedTensorType sourceType = getSourceType();
2457
2458 // Verify result type against inferred type.
2459 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2460 sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2461 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2462 if (result != SliceVerificationResult::Success)
2463 return produceSliceErrorMsg(result, *this, expectedType);
2464
2465 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2466 // to the source tensor.
2467 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2468 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2469 getStaticStrides(), /*generateErrorMessage=*/true);
2470 if (!boundsResult.isValid)
2471 return getOperation()->emitError(boundsResult.errorMessage);
2472
2473 return success();
2474}
2475
2476llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2477 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2478}
2479
2480FailureOr<Value>
2481ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2482 ArrayRef<int64_t> desiredShape) {
2483 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2484 assert(sourceTensorType && "not a ranked tensor type");
2485 auto sourceShape = sourceTensorType.getShape();
2486 if (sourceShape.equals(desiredShape))
2487 return value;
2488 auto maybeRankReductionMask =
2489 mlir::computeRankReductionMask(sourceShape, desiredShape);
2490 if (!maybeRankReductionMask)
2491 return failure();
2492 return createCanonicalRankReducingExtractSliceOp(
2493 b, loc, value,
2494 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2495}
2496
2497LogicalResult ExtractSliceOp::reifyResultShapes(
2498 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2499 reifiedReturnShapes.resize(1);
2500 reifiedReturnShapes[0].reserve(getType().getRank());
2501 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2502 llvm::SmallBitVector droppedDims = getDroppedDims();
2503 for (const auto &size : enumerate(mixedSizes)) {
2504 if (droppedDims.test(size.index()))
2505 continue;
2506 reifiedReturnShapes[0].push_back(size.value());
2507 }
2508 return success();
2509}
2510
2511namespace {
2512/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2513/// This essentially pushes the tensor.cast past its consuming slice when
2514/// `canFoldIntoConsumerOp` is true.
2515///
2516/// Example:
2517/// ```
2518/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2519/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2520/// tensor<3x4xf32>
2521/// ```
2522/// is rewritten into:
2523/// ```
2524/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2525/// tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2526/// ```
2527class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2528public:
2529 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2530
2531 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2532 PatternRewriter &rewriter) const override {
2533 // Any constant operand, just return to let the constant folder kick in.
2534 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2535 return matchPattern(operand, matchConstantIndex());
2536 }))
2537 return failure();
2538
2539 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2540 if (!castOp)
2541 return failure();
2542
2543 if (!canFoldIntoConsumerOp(castOp))
2544 return failure();
2545
2546 // Pattern does not apply if the produced op would not verify.
2547 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2548 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2549 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2550 sliceOp.getStaticStrides());
2551 if (!sliceResult.isValid)
2552 return failure();
2553
2554 // Create folded extract.
2555 Location loc = sliceOp.getLoc();
2556 Value newResult = ExtractSliceOp::create(
2557 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2558 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2559 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2560 sliceOp.getStaticStrides());
2561 rewriter.replaceOp(sliceOp, newResult);
2562 return success();
2563 }
2564};
2565
2566/// Slice elements from `values` into `outValues`. `counts` holds, per
2567/// dimension, the linearized stride (elements per unit step) in `values`.
2568/// The output values can be used to construct a DenseElementsAttr.
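///
/// For instance (illustrative): slicing a 2x3 source with offsets [0, 1],
/// sizes [2, 2] and strides [1, 1] uses counts = [3, 1] and collects the
/// elements at linear indices 1, 2, 4 and 5.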
2569template <typename IterTy, typename ElemTy>
2570static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2571 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2572 ArrayRef<int64_t> strides,
2573 llvm::SmallVectorImpl<ElemTy> *outValues) {
2574 assert(offsets.size() == sizes.size());
2575 assert(offsets.size() == strides.size());
2576 if (offsets.empty())
2577 return;
2578
2579 int64_t offset = offsets.front();
2580 int64_t size = sizes.front();
2581 int64_t stride = strides.front();
2582 if (offsets.size() == 1) {
2583 for (int64_t i = 0; i < size; ++i, offset += stride)
2584 outValues->push_back(*(values + offset));
2585
2586 return;
2587 }
2588
2589 for (int64_t i = 0; i < size; ++i, offset += stride) {
2590 auto begin = values + offset * counts.front();
2591 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2592 offsets.drop_front(), sizes.drop_front(),
2593 strides.drop_front(), outValues);
2594 }
2595}
2596
2597/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2598/// folded operation might introduce more constant data; users can restrict
2599/// the fold through the control function.
2600class ConstantOpExtractSliceFolder final
2601 : public OpRewritePattern<ExtractSliceOp> {
2602public:
2603 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2604
2605 ConstantOpExtractSliceFolder(MLIRContext *context,
2606 ControlConstantExtractSliceFusionFn controlFn)
2607 : OpRewritePattern<ExtractSliceOp>(context),
2608 controlFn(std::move(controlFn)) {}
2609
2610 LogicalResult matchAndRewrite(ExtractSliceOp op,
2611 PatternRewriter &rewriter) const override {
2612 DenseElementsAttr attr;
2613 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2614 return failure();
2615
2616 // A constant splat is handled by fold().
2617 if (attr.isSplat())
2618 return failure();
2619
2620 // Dynamic result shape is not supported.
2621 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2622 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2623 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2624 return failure();
2625
2626 // Customized control over the folding.
2627 if (!controlFn(op))
2628 return failure();
2629
2630 int64_t count = sourceType.getNumElements();
2631 if (count == 0)
2632 return failure();
2633
2634 // Check if there are any dynamic parts, which are not supported.
2635 auto offsets = op.getStaticOffsets();
2636 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2637 return failure();
2638 auto sizes = op.getStaticSizes();
2639 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2640 return failure();
2641 auto strides = op.getStaticStrides();
2642 if (llvm::is_contained(strides, ShapedType::kDynamic))
2643 return failure();
2644
2645 // Compute the stride for each dimension.
2646 SmallVector<int64_t> counts;
2647 ArrayRef<int64_t> shape = sourceType.getShape();
2648 counts.reserve(shape.size());
2649 for (int64_t v : shape) {
2650 count = count / v;
2651 counts.push_back(count);
2652 }
2653
2654 // New attribute constructed by the sliced values.
2655 DenseElementsAttr newAttr;
2656
2657 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2658 SmallVector<APInt> outValues;
2659 outValues.reserve(sourceType.getNumElements());
2660 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2661 elems.begin(), counts, offsets, sizes, strides, &outValues);
2662 newAttr = DenseElementsAttr::get(resultType, outValues);
2663 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2664 SmallVector<APFloat> outValues;
2665 outValues.reserve(sourceType.getNumElements());
2666 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2667 elems.begin(), counts, offsets, sizes, strides, &outValues);
2668 newAttr = DenseElementsAttr::get(resultType, outValues);
2669 }
2670
2671 if (newAttr) {
2672 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2673 return success();
2674 }
2675
2676 return failure();
2677 }
2678
2679private:
2680 /// Controls whether the fold is applied; users can encode their own
2681 /// heuristics in this function.
2682 ControlConstantExtractSliceFusionFn controlFn;
2683};
2684
2685} // namespace
2686
2687void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2688 RewritePatternSet &patterns,
2689 const ControlConstantExtractSliceFusionFn &controlFn) {
2690 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2691}
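// A typical use (illustrative sketch) restricts the fold to small results:
//   tensor::populateFoldConstantExtractSlicePatterns(
//       patterns, [](ExtractSliceOp op) {
//         return op.getType().getNumElements() <= 16;
//       });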
2692
2693/// Return the canonical type of the result of an extract_slice op.
2694struct SliceReturnTypeCanonicalizer {
2695 RankedTensorType operator()(ExtractSliceOp op,
2696 ArrayRef<OpFoldResult> mixedOffsets,
2697 ArrayRef<OpFoldResult> mixedSizes,
2698 ArrayRef<OpFoldResult> mixedStrides) {
2699 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2700 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2701 mixedStrides);
2702 }
2703};
2704
2705/// A canonicalizer wrapper to replace ExtractSliceOps.
2706struct SliceCanonicalizer {
2707 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2708 ExtractSliceOp newOp) {
2709 Value replacement = newOp.getResult();
2710 if (replacement.getType() != op.getType())
2711 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2712 replacement);
2713 rewriter.replaceOp(op, replacement);
2714 }
2715};
2716
2717void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2718 MLIRContext *context) {
2719 results.add<
2720 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2721 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2722 ExtractSliceOpCastFolder>(context);
2723}
2724
2725/// Detect a no-op slice: zero offsets, unit strides, sizes equal to the shape.
2726static LogicalResult
2727foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2728 ShapedType shapedType) {
2729 OpBuilder b(op.getContext());
2730 for (OpFoldResult ofr : op.getMixedOffsets())
2731 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2732 return failure();
2733 // Rank-reducing noops only need to inspect the leading dimensions:
2734 // llvm::zip is appropriate.
2735 auto shape = shapedType.getShape();
2736 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2737 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2738 return failure();
2739 for (OpFoldResult ofr : op.getMixedStrides())
2740 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2741 return failure();
2742 return success();
2743}
2744
2745/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2746/// slice, we can return the InsertSliceOp's source directly.
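///
/// For instance (illustrative):
///   %i = tensor.insert_slice %src into %dst[0, 0] [4, 4] [1, 1]
///   %e = tensor.extract_slice %i[0, 0] [4, 4] [1, 1]
/// folds %e to %src, provided %src and %e have the same type.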
2747// TODO: This only checks the immediate producer; extend to go up the
2748// insert/extract chain if the slices are disjoint.
2749static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2750 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2751
2752 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2753 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2754 insertOp.isSameAs(extractOp, isSame))
2755 return insertOp.getSource();
2756
2757 return {};
2758}
2759
2760OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2761 if (OpFoldResult reshapedSource = reshapeConstantSource(
2762 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2763 getResult().getType()))
2764 return reshapedSource;
2765 if (getSourceType() == getType() &&
2766 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2767 return this->getSource();
2768 if (Value slice = foldExtractAfterInsertSlice(*this))
2769 return slice;
2770
2771 return OpFoldResult();
2772}
2773
2774Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2775 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2776 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2777 unsigned rank = rankedTensorType.getRank();
2778 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2779 SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, tensor);
2780 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2781 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2782 offsets, sizes, strides);
2783}
2784
2785//===----------------------------------------------------------------------===//
2786// InsertSliceOp
2787//===----------------------------------------------------------------------===//
2788
2789void InsertSliceOp::getAsmResultNames(
2790 function_ref<void(Value, StringRef)> setNameFn) {
2791 setNameFn(getResult(), "inserted_slice");
2792}
2793
2794// Build an InsertSliceOp with mixed static and dynamic entries.
2795void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2796 Value dest, ArrayRef<OpFoldResult> offsets,
2797 ArrayRef<OpFoldResult> sizes,
2798 ArrayRef<OpFoldResult> strides,
2799 ArrayRef<NamedAttribute> attrs) {
2800 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2801 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2802 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2803 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2804 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2805 result.addAttributes(attrs);
2806 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2807 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2808 b.getDenseI64ArrayAttr(staticSizes),
2809 b.getDenseI64ArrayAttr(staticStrides));
2810}
2811
2812/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2813/// Range vector.
2814void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2815 Value dest, ArrayRef<Range> ranges,
2816 ArrayRef<NamedAttribute> attrs) {
2817 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2818 build(b, result, source, dest, offsets, sizes, strides, attrs);
2819}
2820
2821// Build an InsertSliceOp with dynamic entries.
2822void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2823 Value dest, ValueRange offsets, ValueRange sizes,
2824 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2825 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2826 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2827 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2828 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2829 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2830 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2831 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2832}
2833
2834/// Rank-reducing type verification for both InsertSliceOp and
2835/// ParallelInsertSliceOp.
2836static SliceVerificationResult verifyInsertSliceOp(
2837 RankedTensorType srcType, RankedTensorType dstType,
2838 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2839 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2840 // insert_slice is the inverse of extract_slice, use the same type
2841 // inference.
2842 RankedTensorType expected = ExtractSliceOp::inferResultType(
2843 dstType, staticOffsets, staticSizes, staticStrides);
2844 if (expectedType)
2845 *expectedType = expected;
2846 return isRankReducedType(expected, srcType);
2847}
2848
2849/// Verifier for InsertSliceOp.
2850LogicalResult InsertSliceOp::verify() {
2851 // Verify result type against inferred type.
2852 RankedTensorType expectedType;
2853 SliceVerificationResult result =
2854 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2855 getStaticSizes(), getStaticStrides(), &expectedType);
2856 if (result != SliceVerificationResult::Success)
2857 return produceSliceErrorMsg(result, *this, expectedType);
2858
2859 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2860 // to the destination tensor.
2861 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2862 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2863 getStaticStrides(), /*generateErrorMessage=*/true);
2864 if (!boundsResult.isValid)
2865 return getOperation()->emitError(boundsResult.errorMessage);
2866
2867 return success();
2868}
2869
2870/// If we have two consecutive InsertSliceOp writing to the same slice, we
2871/// can mutate the second InsertSliceOp's destination to the first one's.
2872///
2873/// Example:
2874///
2875/// ```mlir
2876/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2877/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2878/// ```
2879///
2880/// folds into:
2881///
2882/// ```mlir
2883/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2884/// ```
2885///
2886/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2887static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2888 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2889
2890 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2891 if (!prevInsertOp ||
2892 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2893 !prevInsertOp.isSameAs(insertOp, isSame))
2894 return failure();
2895
2896 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2897 return success();
2898}
2899
2900/// Folds round-trip extract/insert slice op pairs.
2901/// Example:
2902/// ```mlir
2903/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2904/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2905/// ```
2906/// can be folded into %val.
2907static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2908 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2909
2910 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2911 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2912 !extractOp.isSameAs(insertOp, isSame))
2913 return nullptr;
2914
2915 return extractOp.getSource();
2916}
2917
2918OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2919 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2920 getSourceType() == getType() &&
2921 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2922 return this->getSource();
2923 if (succeeded(foldInsertAfterInsertSlice(*this)))
2924 return getResult();
2925 if (auto result = foldInsertAfterExtractSlice(*this))
2926 return result;
2927 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2928 return getDest();
2929 return OpFoldResult();
2930}
2931
2932LogicalResult InsertSliceOp::reifyResultShapes(
2933 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2934 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2935 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2936 return success();
2937}
2938
2939namespace {
2940/// Pattern to rewrite an insert_slice op with constant arguments.
2941///
2942/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2943template <typename InsertOpTy>
2944class InsertSliceOpConstantArgumentFolder final
2945 : public OpRewritePattern<InsertOpTy> {
2946public:
2947 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2948
2949 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2950 PatternRewriter &rewriter) const override {
2951 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2952 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2953 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2954
2955 // No constant operands were folded, just return.
2956 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2957 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2958 failed(foldDynamicStrideList(mixedStrides)))
2959 return failure();
2960
2961 // Pattern does not apply if the produced op would not verify.
2962 SliceBoundsVerificationResult sliceResult =
2963 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2964 mixedOffsets, mixedSizes, mixedStrides);
2965 if (!sliceResult.isValid)
2966 return failure();
2967
2968 // Create the new op in canonical form.
2969 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2970 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2971 mixedOffsets, mixedSizes, mixedStrides);
2972 Value toInsert = insertSliceOp.getSource();
2973 if (sourceType != insertSliceOp.getSourceType()) {
2974 OpBuilder::InsertionGuard g(rewriter);
2975 // The only difference between InsertSliceOp and ParallelInsertSliceOp
2976 // is that the insertion point is just before the InParallelOp in
2977 // the parallel case.
2978 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
2979 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2980 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
2981 sourceType, toInsert);
2982 }
2983 rewriter.replaceOpWithNewOp<InsertOpTy>(
2984 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2985 mixedSizes, mixedStrides);
2986 return success();
2987 }
2988};
2989
2990/// Fold tensor_casts with insert_slice operations. If the source or
2991/// destination tensor is a tensor_cast that removes static type information,
2992/// the cast is folded into the insert_slice operation. E.g.:
2993///
2994/// ```mlir
2995/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2996/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2997/// ```
2998///
2999/// folds into:
3000///
3001/// ```mlir
3002/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3003/// ```
3004///
3005/// Note: When folding a cast on the destination tensor, the result of the
3006/// insert_slice operation is cast to ensure that the type of the result did
3007/// not change.
3008///
3009/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3010template <typename InsertOpTy>
3011struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3012 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3013
3014 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3015 PatternRewriter &rewriter) const override {
3016 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3017 return matchPattern(operand, matchConstantIndex());
3018 }))
3019 return failure();
3020
3021 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3022 auto castOp = v.getDefiningOp<tensor::CastOp>();
3023 if (!castOp || !canFoldIntoConsumerOp(castOp))
3024 return std::nullopt;
3025 return castOp.getSource();
3026 };
3027 std::optional<Value> sourceCastSource =
3028 getSourceOfCastOp(insertSliceOp.getSource());
3029 std::optional<Value> destCastSource =
3030 getSourceOfCastOp(insertSliceOp.getDest());
3031 if (!sourceCastSource && !destCastSource)
3032 return failure();
3033
3034 auto src =
3035 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3036 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3037 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3038 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3039 if (!srcType || !dstType)
3040 return failure();
3041
3042 // The tensor.cast source could have additional static information not seen
3043 // in the insert slice op static sizes, so we ignore dynamic dims when
3044 // computing the rank reduction mask.
3045 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3046 auto rankReductionMask = computeRankReductionMask(
3047 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3048 if (!rankReductionMask.has_value())
3049 return failure();
3050 // Replace dimensions in the insert slice op with corresponding static dims
3051 // from the cast source type. If the insert slice sizes have static dims
3052 // that are not static in the tensor.cast source (i.e., when the cast op
3053 // casts a dynamic dim to static), the dim should not be replaced, and the
3054 // pattern will fail later in `verifyInsertSliceOp`.
3055 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3056 int64_t rankReducedIdx = 0;
3057 for (auto [idx, size] : enumerate(staticSizes)) {
3058 if (!rankReductionMask.value().contains(idx) &&
3059 !srcType.isDynamicDim(rankReducedIdx)) {
3060 mixedSizes[idx] = getAsIndexOpFoldResult(
3061 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3062 size = srcType.getDimSize(rankReducedIdx++);
3063 }
3064 }
3065
3066 // Pattern does not apply if the produced op would not verify.
3067 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3068 staticSizes, insertSliceOp.getStaticStrides()) !=
3069 SliceVerificationResult::Success)
3070 return failure();
3071 SliceBoundsVerificationResult sliceResult =
3072 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3073 mixedSizes, insertSliceOp.getMixedStrides());
3074 if (!sliceResult.isValid)
3075 return failure();
3076
3077 Operation *replacement =
3078 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3079 insertSliceOp.getMixedOffsets(), mixedSizes,
3080 insertSliceOp.getMixedStrides());
3081
3082 // In the parallel case there is no result and so nothing to cast.
3083 bool isParallelInsert =
3084 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3085 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3086 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3087 insertSliceOp.getDestType(),
3088 replacement->getResult(0));
3089 }
3090 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3091 return success();
3092 }
3093};
3094
3095/// If additional static type information can be deduced from an insert_slice's
3096/// size operands, insert an explicit cast of the op's source operand. This
3097/// enables other canonicalization patterns that are matching for tensor_cast
3098/// ops such as `ForOpTensorCastFolder` in SCF.
3099///
3100/// Example:
3101///
3102/// ```mlir
3103/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3104/// : tensor<?x?xf32> into ...
3105/// ```
3106///
3107/// folds into:
3108///
3109/// ```mlir
3110/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3111/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3112/// : tensor<64x64xf32> into ...
3113/// ```
3114///
3115/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3116template <typename InsertOpTy>
3117struct InsertSliceOpSourceCastInserter final
3118 : public OpRewritePattern<InsertOpTy> {
3119 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3120
3121 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3122 PatternRewriter &rewriter) const override {
3123 RankedTensorType srcType = insertSliceOp.getSourceType();
3124 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3125 return failure();
3126 SmallVector<int64_t> newSrcShape(srcType.getShape());
3127 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3128 if (std::optional<int64_t> constInt =
3129 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3130 // Bail on invalid IR.
3131 if (*constInt < 0)
3132 return failure();
3133 newSrcShape[i] = *constInt;
3134 }
3135 }
3136 if (!hasValidSizesOffsets(newSrcShape))
3137 return failure();
3138
3139 RankedTensorType newSrcType = RankedTensorType::get(
3140 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3141 if (srcType == newSrcType ||
3142 !preservesStaticInformation(srcType, newSrcType) ||
3143 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3144 return failure();
3145
3146 // newSrcType is:
3147 // 1) Different from srcType.
3148 // 2) "More static" than srcType.
3149 // 3) Cast-compatible with srcType.
3150 // Insert the cast.
3151 OpBuilder::InsertionGuard g(rewriter);
3152 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3153 // that the insertion point is just before the InParallelOp in the
3154 // parallel case.
3155 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3156 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3157 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3158 newSrcType, insertSliceOp.getSource());
3159 rewriter.replaceOpWithNewOp<InsertOpTy>(
3160 insertSliceOp, cast, insertSliceOp.getDest(),
3161 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3162 insertSliceOp.getMixedStrides());
3163 return success();
3164 }
3165};
3166} // namespace
3167
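/// Note (illustrative, value names hypothetical): for a rank-reducing insert
/// such as
///
/// ```mlir
/// %r = tensor.insert_slice %src into %dest[0, 0] [1, 8] [1, 1]
///     : tensor<8xf32> into tensor<4x8xf32>
/// ```
///
/// getDroppedDims() would report dimension 0 as dropped, since the unit slice
/// size in that dimension has no counterpart in the rank-reduced source shape.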
3168llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3169 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3170}
3171
3172void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3173 MLIRContext *context) {
3174 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3175 InsertSliceOpCastFolder<InsertSliceOp>,
3176 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3177}
3178
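/// The helper below builds an insert_slice that writes `tensor` over all of
/// `dest`: zero offsets, unit strides, and the destination's sizes. A sketch
/// of the IR it may produce (illustrative only, value names hypothetical):
///
/// ```mlir
/// %c1 = arith.constant 1 : index
/// %sz = tensor.dim %dest, %c1 : tensor<1x?xf32>
/// %r = tensor.insert_slice %t into %dest[0, 0] [1, %sz] [1, 1]
///     : tensor<?xf32> into tensor<1x?xf32>
/// ```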
3179Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3180                                                             Location loc,
3181 Value tensor,
3182 Value dest) {
3183 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3184 unsigned rank = rankedTensorType.getRank();
3185 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3186 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3187 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3188 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3189 sizes, strides);
3190}
3191
3192//===----------------------------------------------------------------------===//
3193// PadOp
3194//===----------------------------------------------------------------------===//
3195
3196void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3197 setNameFn(getResult(), "padded");
3198}
3199
3200LogicalResult PadOp::verify() {
3201 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3202 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3203 auto expectedType =
3204 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3205 if (!expectedType) {
3206 return emitError("failed to infer expectedType from sourceType ")
3207 << sourceType << ", specified resultType is " << resultType;
3208 }
3209 if (resultType.getRank() != expectedType.getRank()) {
3210 return emitError("specified type ")
3211 << resultType << " does not match the inferred type "
3212 << expectedType;
3213 }
3214 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3215 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3216 continue;
3217 if (expectedType.isDynamicDim(i))
3218 continue;
3219 return emitError("specified type ")
3220 << resultType << " does not match the inferred type "
3221 << expectedType;
3222 }
3223
3224 return success();
3225}
3226
3227LogicalResult PadOp::verifyRegions() {
3228 auto &region = getRegion();
3229 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3230 Block &block = region.front();
3231 if (block.getNumArguments() != rank)
3232 return emitError("expected the block to have ") << rank << " arguments";
3233
3234 // Note: the number and type of yield values are checked in the YieldOp.
3235 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3236 if (!en.value().isIndex())
3237 return emitOpError("expected block argument ")
3238 << (en.index() + 1) << " to be an index";
3239 }
3240
3241 // Ensure that the region yields an element of the right type.
3242 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3243 if (yieldOp.getValue().getType() !=
3244 llvm::cast<ShapedType>(getType()).getElementType())
3245 return emitOpError("expected yield type to match shape element type");
3246
3247 return success();
3248}
3249
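/// Worked example for the shape inference below (illustrative): padding a
/// tensor<4x?xf32> with staticLow = [1, 0] and staticHigh = [2, 3] infers
/// tensor<7x?xf32>; dimension 0 becomes 4 + 1 + 2 = 7 and dimension 1 stays
/// dynamic because the source dimension is dynamic:
///
/// ```mlir
/// %p = tensor.pad %src low[1, 0] high[2, 3] {
///   ...
/// } : tensor<4x?xf32> to tensor<7x?xf32>
/// ```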
3250RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3251 ArrayRef<int64_t> staticLow,
3252 ArrayRef<int64_t> staticHigh,
3253 ArrayRef<int64_t> resultShape) {
3254 unsigned rank = sourceType.getRank();
3255 if (staticLow.size() != rank)
3256 return RankedTensorType();
3257 if (staticHigh.size() != rank)
3258 return RankedTensorType();
3259 if (!resultShape.empty() && resultShape.size() != rank)
3260 return RankedTensorType();
3261
3262 SmallVector<int64_t, 4> inferredShape;
3263 for (auto i : llvm::seq<unsigned>(0, rank)) {
3264 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3265 staticHigh[i] == ShapedType::kDynamic) {
3266 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3267 : resultShape[i]);
3268 } else {
3269 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3270 assert((resultShape.empty() || size == resultShape[i] ||
3271 resultShape[i] == ShapedType::kDynamic) &&
3272 "mismatch between inferred shape and result shape");
3273 inferredShape.push_back(size);
3274 }
3275 }
3276
3277 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3278}
3279
3280void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3281 Value source, ArrayRef<int64_t> staticLow,
3282 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3283 bool nofold, ArrayRef<NamedAttribute> attrs) {
3284 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3285 if (!resultType)
3286 resultType = inferResultType(sourceType, staticLow, staticHigh);
3287 result.addAttributes(attrs);
3288 build(b, result, resultType, source, low, high,
3289 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3290 nofold ? b.getUnitAttr() : UnitAttr());
3291}
3292
3293void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3294 Value source, ValueRange low, ValueRange high, bool nofold,
3295 ArrayRef<NamedAttribute> attrs) {
3296 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3297 unsigned rank = sourceType.getRank();
3298 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3299 build(b, result, resultType, source, staticVector, staticVector, low, high,
3300 nofold, attrs);
3301}
3302
3303void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3304 Value source, ArrayRef<OpFoldResult> low,
3305 ArrayRef<OpFoldResult> high, bool nofold,
3306 ArrayRef<NamedAttribute> attrs) {
3307 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3308 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3309 SmallVector<int64_t, 4> staticLow, staticHigh;
3310 // staticLow and staticHigh have full information of the padding config.
3311 // This will grow staticLow and staticHigh with 1 value. If the config is
3312  // dynamic (i.e., not a constant), dynamicLow and dynamicHigh will grow with 1
3313 // value as well.
3314 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3315 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3316 if (!resultType) {
3317 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3318 }
3319 assert(llvm::isa<RankedTensorType>(resultType));
3320 result.addAttributes(attrs);
3321 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3322 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3323 nofold ? b.getUnitAttr() : UnitAttr());
3324}
3325
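/// The builder overload below also creates the pad body: a single block with
/// one index argument per source dimension that yields the constant padding
/// value. Schematically (illustrative only, value names hypothetical), it
/// produces IR such as:
///
/// ```mlir
/// %p = tensor.pad %src low[1, 1] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %padVal : f32
/// } : tensor<2x2xf32> to tensor<4x4xf32>
/// ```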
3326void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3327 Value source, ArrayRef<OpFoldResult> low,
3328 ArrayRef<OpFoldResult> high, Value constantPadValue,
3329 bool nofold, ArrayRef<NamedAttribute> attrs) {
3330 build(b, result, resultType, source, low, high, nofold, attrs);
3331
3332 // Add a region and a block to yield the pad value.
3333 Region *region = result.regions[0].get();
3334 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3335 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3336 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3337
3338 // `builder.createBlock` changes the insertion point within the block. Create
3339 // a guard to reset the insertion point of the builder after it is destroyed.
3340 OpBuilder::InsertionGuard guard(b);
3341 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3342 tensor::YieldOp::create(b, result.location, constantPadValue);
3343}
3344
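/// Worked example for getPaddedDims() below (illustrative): with
/// getMixedLowPad() = [0, 2] and getMixedHighPad() = [%h, 0], both bits of the
/// returned vector are set; dimension 1 is padded by the constant low pad and
/// dimension 0 by the dynamic high pad, which counts as padded because it is
/// not a known zero.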
3345llvm::SmallBitVector PadOp::getPaddedDims() {
3346 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3347 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3348 for (const auto &en : enumerate(paddingWidths))
3349 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3350 paddedDims.set(en.index());
3351 };
3352 extractPaddedDims(getMixedLowPad());
3353 extractPaddedDims(getMixedHighPad());
3354 return paddedDims;
3355}
3356
3357namespace {
3358// Folds tensor.pad when the low and high padding is all static zeros and the
3359// nofold attribute doesn't request otherwise.
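//
// Illustrative sketch (not from the upstream source; value names
// hypothetical):
//
// ```mlir
// %p = tensor.pad %src low[0, 0] high[0, 0] {
// ^bb0(%i: index, %j: index):
//   tensor.yield %cst : f32
// } : tensor<?x8xf32> to tensor<4x8xf32>
// ```
//
// canonicalizes to:
//
// ```mlir
// %p = tensor.cast %src : tensor<?x8xf32> to tensor<4x8xf32>
// ```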
3360struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3361 using OpRewritePattern<PadOp>::OpRewritePattern;
3362
3363 LogicalResult matchAndRewrite(PadOp padTensorOp,
3364 PatternRewriter &rewriter) const override {
3365 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3366 return failure();
3367 if (padTensorOp.getNofold())
3368 return failure();
3369 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3370 padTensorOp, padTensorOp.getResult().getType(),
3371 padTensorOp.getSource());
3372 return success();
3373 }
3374};
3375
3376// Fold CastOp into PadOp when adding static information.
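//
// Rough sketch of the net effect once canonicalization converges
// (illustrative, value names hypothetical):
//
// ```mlir
// %0 = tensor.cast %arg : tensor<4x8xf32> to tensor<?x?xf32>
// %1 = tensor.pad %0 low[0, 0] high[1, 2] {
//   ...
// } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.pad %arg low[0, 0] high[1, 2] {
//   ...
// } : tensor<4x8xf32> to tensor<5x10xf32>
// %1 = tensor.cast %0 : tensor<5x10xf32> to tensor<?x?xf32>
// ```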
3377struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3378 using OpRewritePattern<PadOp>::OpRewritePattern;
3379
3380 LogicalResult matchAndRewrite(PadOp padTensorOp,
3381 PatternRewriter &rewriter) const override {
3382 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3383 if (!tensor::canFoldIntoConsumerOp(castOp))
3384 return failure();
3385
3386 auto newResultType = PadOp::inferResultType(
3387 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3388 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3389 padTensorOp.getResultType().getShape());
3390
3391 if (newResultType == padTensorOp.getResultType()) {
3392 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3393 padTensorOp.getSourceMutable().assign(castOp.getSource());
3394 });
3395 } else {
3396 auto newOp = PadOp::create(
3397 rewriter, padTensorOp->getLoc(), newResultType,
3398 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3399 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3400 padTensorOp.getHigh(), padTensorOp.getNofold(),
3401 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3402 IRMapping mapper;
3403 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3404
3405 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3406 padTensorOp, padTensorOp.getResultType(), newOp);
3407 }
3408 return success();
3409 }
3410};
3411
3412// Fold CastOp using the result of PadOp back into the latter if it adds
3413// static information.
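//
// Illustrative sketch (not from the upstream source; value names
// hypothetical):
//
// ```mlir
// %0 = tensor.pad %arg low[0, 0] high[1, 2] {
//   ...
// } : tensor<?x?xf32> to tensor<?x?xf32>
// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<5x10xf32>
// ```
//
// becomes:
//
// ```mlir
// %1 = tensor.pad %arg low[0, 0] high[1, 2] {
//   ...
// } : tensor<?x?xf32> to tensor<5x10xf32>
// ```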
3414struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3415 using OpRewritePattern<PadOp>::OpRewritePattern;
3416
3417 LogicalResult matchAndRewrite(PadOp padTensorOp,
3418 PatternRewriter &rewriter) const override {
3419 if (!padTensorOp.getResult().hasOneUse())
3420 return failure();
3421 auto tensorCastOp =
3422 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3423 if (!tensorCastOp)
3424 return failure();
3425 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3426 tensorCastOp.getDest().getType()))
3427 return failure();
3428
3429 auto replacementOp = PadOp::create(
3430 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3431 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3432 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3433 padTensorOp.getHigh(), padTensorOp.getNofold(),
3434 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3435 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3436
3437 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3438 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3439 return success();
3440 }
3441};
3442
3443/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3444/// different dimensions. The pattern applies if the following preconditions
3445/// hold:
3446/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3447/// 2) the tensor::ExtractSliceOps have only unit-strides,
3448/// 3) the tensor::PadOps perform only high-padding,
3449/// 4) the tensor::PadOps have the same constant padding value,
3450/// 5) the tensor::PadOps do not have common padding dimensions,
3451/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3452/// zero-offset for every dimension.
3453/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3454///    the padded source dimensions.
3456///
3457/// Example:
3458///
3459/// ```mlir
3460/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3461/// : tensor<64x64xf32> to tensor<?x64xf32>
3462/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3463/// } : tensor<?x64xf32> to tensor<8x64xf32>
3464/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3465/// : tensor<8x64xf32> to tensor<8x?xf32>
3466/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3467/// } : tensor<8x?xf32> to tensor<8x4xf32>
3468/// ```
3469///
3470/// folds into:
3471///
3472/// ```mlir
3473/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3474/// : tensor<64x64xf32> to tensor<?x?xf32>
3475/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3476/// } : tensor<?x?xf32> to tensor<8x4xf32>
3477/// ```
3478struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3479 using OpRewritePattern<PadOp>::OpRewritePattern;
3480
3481 LogicalResult matchAndRewrite(PadOp padOp,
3482 PatternRewriter &rewriter) const override {
3483 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3484 if (!innerSliceOp)
3485 return failure();
3486 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3487 if (!outerPadOp || outerPadOp.getNofold())
3488 return failure();
3489 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3490 if (!outerSliceOp)
3491 return failure();
3492
3493 // 1) Fail if the chain is rank-reducing.
3494 int64_t rank = padOp.getSourceType().getRank();
3495 if (outerSliceOp.getSourceType().getRank() != rank) {
3496 return rewriter.notifyMatchFailure(padOp,
3497 "cannot fold rank-reducing chain");
3498 }
3499
3500 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3501 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3502 return rewriter.notifyMatchFailure(
3503 padOp, "cannot fold non-unit stride ExtractSliceOps");
3504 }
3505
3506 // 3) Fail if the tensor::PadOps have non-zero low padding.
3507 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3508 return rewriter.notifyMatchFailure(padOp,
3509 "cannot fold PadOps with low padding");
3510 }
3511
3512 // 4) Fail if the tensor::PadOps padding values do not match.
3513 Attribute innerAttr, outerAttr;
3514 Value innerValue = padOp.getConstantPaddingValue();
3515 Value outerValue = outerPadOp.getConstantPaddingValue();
3516 if (!innerValue || !outerValue ||
3517 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3518 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3519 innerAttr != outerAttr) {
3520 return rewriter.notifyMatchFailure(
3521 padOp, "cannot fold PadOps with different padding values");
3522 }
3523
3524 // 5) Fail if a dimension is padded by both tensor::PadOps.
3525 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3526 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3527 if (innerDims.anyCommon(outerDims)) {
3528 return rewriter.notifyMatchFailure(
3529 padOp, "cannot fold PadOps with common padding dimensions");
3530 }
3531
3532 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3533 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3534    // for every dimension, and use the offset of the other pair. Fail if no
3535 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3536 // exists.
3537 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3538 for (auto en : enumerate(newOffsets)) {
3539 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3540 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3541 if (!innerDims.test(en.index()) &&
3542 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3543 en.value() = outerOffset;
3544 continue;
3545 }
3546 if (!outerDims.test(en.index()) &&
3547 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3548 en.value() = innerOffset;
3549 continue;
3550 }
3551 return rewriter.notifyMatchFailure(
3552 padOp, "cannot find zero-offset and zero-padding pair");
3553 }
3554
3555 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3556 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3557 // outer tensor::PadOp and fail if the size of the inner
3558 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3559 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3560 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3561 for (auto en : enumerate(newSizes)) {
3562 if (!outerDims.test(en.index()))
3563 continue;
3564 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3565 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3566 assert(ShapedType::isStatic(sourceSize) &&
3567 "expected padded dimension to have a static size");
3568 if (getConstantIntValue(sliceSize) != sourceSize) {
3569 return rewriter.notifyMatchFailure(
3570 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3571 "match the size of the outer padding");
3572 }
3573 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3574 }
3575
3576 // Combine the high paddings of the two tensor::PadOps.
3577 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3578 for (auto en : enumerate(newHighPad)) {
3579 if (innerDims.test(en.index()))
3580 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3581 if (outerDims.test(en.index()))
3582 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3583 }
3584
3585 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3586 // the two paddings in one step.
3587 auto newSliceOp = ExtractSliceOp::create(
3588 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3589 newSizes, innerSliceOp.getMixedStrides());
3590 auto newPadOp = PadOp::create(
3591 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3592 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3593 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3594 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3595 newPadOp.getRegion().begin());
3596 rewriter.replaceOp(padOp, newPadOp.getResult());
3597 return success();
3598 }
3599};
3600
3601struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3602 using OpRewritePattern<PadOp>::OpRewritePattern;
3603
3604 LogicalResult matchAndRewrite(PadOp padTensorOp,
3605 PatternRewriter &rewriter) const override {
3606 Value input = padTensorOp.getSource();
3607 if (!llvm::isa<RankedTensorType>(input.getType()))
3608 return failure();
3609 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3610 auto inputRank = inputDims.size();
3611
3612 auto oldResultType =
3613 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3614 if (!oldResultType)
3615 return failure();
3616
3617 auto outputDims = oldResultType.getShape();
3618
3619 // Extract the static info from the high and low operands.
3620 SmallVector<int64_t> constOperandsLow;
3621 SmallVector<Value> newLows;
3622 for (auto operand : padTensorOp.getLow()) {
3623 APSInt intOp;
3624 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3625 constOperandsLow.push_back(ShapedType::kDynamic);
3626 newLows.push_back(operand);
3627 continue;
3628 }
3629 constOperandsLow.push_back(intOp.getExtValue());
3630 }
3631 SmallVector<int64_t> constOperandsHigh;
3632 SmallVector<Value> newHighs;
3633 for (auto operand : padTensorOp.getHigh()) {
3634 APSInt intOp;
3635 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3636 constOperandsHigh.push_back(ShapedType::kDynamic);
3637 newHighs.push_back(operand);
3638 continue;
3639 }
3640 constOperandsHigh.push_back(intOp.getExtValue());
3641 }
3642
3643 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3644 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3645
3646 // Verify the op is well-formed.
3647 if (inputDims.size() != outputDims.size() ||
3648 inputDims.size() != constLow.size() ||
3649 inputDims.size() != constHigh.size())
3650 return failure();
3651
3652 auto lowCount = 0;
3653 auto highCount = 0;
3654 for (size_t i = 0; i < inputRank; i++) {
3655 if (constLow[i] == ShapedType::kDynamic)
3656 constLow[i] = constOperandsLow[lowCount++];
3657 if (constHigh[i] == ShapedType::kDynamic)
3658 constHigh[i] = constOperandsHigh[highCount++];
3659 }
3660
3661 auto staticLow = ArrayRef<int64_t>(constLow);
3662 auto staticHigh = ArrayRef<int64_t>(constHigh);
3663
3664 // Calculate the output sizes with the static information.
3665 SmallVector<int64_t> newOutDims;
3666 for (size_t i = 0; i < inputRank; i++) {
3667 if (outputDims[i] == ShapedType::kDynamic) {
3668 newOutDims.push_back(
3669 (staticLow[i] == ShapedType::kDynamic ||
3670 staticHigh[i] == ShapedType::kDynamic ||
3671 inputDims[i] == ShapedType::kDynamic
3672 ? ShapedType::kDynamic
3673 : inputDims[i] + staticLow[i] + staticHigh[i]));
3674 } else {
3675 newOutDims.push_back(outputDims[i]);
3676 }
3677 }
3678
3679 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3680 llvm::all_of(newOutDims,
3681 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3682 return failure();
3683
3684 // Rewrite the op using the new static type.
3685 auto newResultType = RankedTensorType::get(
3686 newOutDims, padTensorOp.getType().getElementType());
3687 auto newOp = PadOp::create(
3688 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3689 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3690 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3691
3692 IRMapping mapper;
3693 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3694 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3695 newOp);
3696
3697 return success();
3698 }
3699};
3700
3701/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3702///
3703/// Example:
3704///
3705/// ```mlir
3706/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3707/// tensor.yield %val
3708/// } : tensor<2x2xf32> to tensor<2x5xf32>
3709/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3710/// tensor.yield %val
3711/// } : tensor<2x5xf32> to tensor<5x7xf32>
3712/// ```
3713///
3714/// folds into:
3715///
3716/// ```mlir
3717/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3718/// tensor.yield %val
3719/// } : tensor<2x2xf32> to tensor<5x7xf32>
3720/// ```
3721struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3722 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3723
3724 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3725 PatternRewriter &rewriter) const override {
3726 if (padOp.getNofold()) {
3727 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3728 }
3729
3730 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3731 if (!producerPad || producerPad.getNofold()) {
3732 return rewriter.notifyMatchFailure(
3733 padOp, "producer is not a foldable tensor.pad op");
3734 }
3735
3736 // Fail if the tensor::PadOps padding values do not match.
3737 Value consumerPadValue = padOp.getConstantPaddingValue();
3738 Value producerPadValue = producerPad.getConstantPaddingValue();
3739 if (!consumerPadValue || !producerPadValue ||
3740 consumerPadValue != producerPadValue) {
3741 return rewriter.notifyMatchFailure(
3742 padOp,
3743 "cannot fold PadOps with different or non-constant padding values");
3744 }
3745
3746 Location loc = padOp.getLoc();
3747 AffineExpr d0, d1;
3748 bindDims(rewriter.getContext(), d0, d1);
3749
3750 // Combine the low/high paddings of the two tensor::PadOps.
3751 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3752 ArrayRef<OpFoldResult> producerPaddings) {
3753 SmallVector<OpFoldResult> sumPaddings;
3754 for (auto [consumerIndex, producerIndex] :
3755 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3756 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3757 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3758 }
3759 return sumPaddings;
3760 };
3761
3762 SmallVector<OpFoldResult> newHighPad =
3763 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3764 SmallVector<OpFoldResult> newLowPad =
3765 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3766
3767 auto newPadOp = tensor::PadOp::create(
3768 rewriter, padOp.getLoc(), padOp.getResultType(),
3769 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3770 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3771 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3772 newPadOp.getRegion().begin());
3773 rewriter.replaceOp(padOp, newPadOp.getResult());
3774 return success();
3775 }
3776};
3777
3778} // namespace
3779
3780LogicalResult
3781PadOp::reifyResultShapes(OpBuilder &b,
3782 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3783 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3784 SmallVector<OpFoldResult> lp = getMixedLowPad();
3785 SmallVector<OpFoldResult> hp = getMixedHighPad();
3786 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3787 if (!getType().isDynamicDim(i)) {
3788 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3789 continue;
3790 }
3791 Location loc = getLoc();
3792 Value dim = b.createOrFold<tensor::DimOp>(
3793 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3794
3795 AffineExpr d0, d1, d2;
3796 bindDims(b.getContext(), d0, d1, d2);
3797 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3798 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3799 }
3800 return success();
3801}
3802
3803void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3804 MLIRContext *context) {
3805 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3806 FoldOrthogonalPaddings, FoldStaticPadding,
3807 FoldConsecutiveConstantPadding>(context);
3808}
3809
3810/// Return the padding value of the PadOp if it is constant. In this context,
3811/// "constant" means an actual constant or "defined outside of the block".
3812///
3813/// Values are considered constant in three cases:
3814/// - A ConstantLike value.
3815/// - A basic block argument from a different block.
3816/// - A value defined outside of the block.
3817///
3818/// If the padding value is not constant, an empty Value is returned.
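/// Illustrative example (not from the upstream source; value names
/// hypothetical): for the pad below, getConstantPaddingValue() returns %cst,
/// both because %cst is a constant and because it is defined outside the pad
/// body. If the body instead yielded a value computed from the block's index
/// arguments, an empty Value would be returned.
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %src low[1, 1] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<2x2xf32> to tensor<4x4xf32>
/// ```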
3819Value PadOp::getConstantPaddingValue() {
3820 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3821 if (!yieldOp)
3822 return {};
3823 Value padValue = yieldOp.getValue();
3824 // Check if yield value is a constant.
3825 if (matchPattern(padValue, m_Constant()))
3826 return padValue;
3827 // Check if yield value is defined inside the PadOp block.
3828 if (padValue.getParentBlock() == &getRegion().front())
3829 return {};
3830 // Else: Yield value defined outside of the PadOp block.
3831 return padValue;
3832}
3833
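/// The folder below drops a pad that provably adds nothing: the result shape
/// is static, equal to the source type, and nofold is not set. For example
/// (illustrative), `tensor.pad %src low[0, 0] high[0, 0] {...} :
/// tensor<4x4xf32> to tensor<4x4xf32>` folds to plain `%src`. Unlike
/// FoldStaticZeroPadding above, no cast is needed because the types already
/// match.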
3834OpFoldResult PadOp::fold(FoldAdaptor) {
3835 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3836 !getNofold())
3837 return getSource();
3838 return {};
3839}
3840
3841//===----------------------------------------------------------------------===//
3842// ParallelInsertSliceOp
3843//===----------------------------------------------------------------------===//
3844
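/// For orientation (illustrative sketch, not from the upstream source; value
/// names hypothetical): a parallel_insert_slice only appears inside the
/// in-parallel terminator of an op such as scf.forall, and getTiedOpResult()
/// below maps it to the corresponding result of that parent:
///
/// ```mlir
/// %res = scf.forall (%i) in (4) shared_outs(%out = %dest) -> (tensor<4x8xf32>) {
///   %tile = ...
///   scf.forall.in_parallel {
///     tensor.parallel_insert_slice %tile into %out[%i, 0] [1, 8] [1, 1]
///         : tensor<8xf32> into tensor<4x8xf32>
///   }
/// }
/// ```
///
/// Here the parallel_insert_slice would be tied to %res.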
3845OpResult ParallelInsertSliceOp::getTiedOpResult() {
3846 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3847 for (const auto &it :
3848 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3849 Operation &nextOp = it.value();
3850 if (&nextOp == getOperation())
3851 return parallelCombiningParent.getParentResult(it.index());
3852 }
3853 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3854}
3855
3856// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3857void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3858 Value source, Value dest,
3859 ArrayRef<OpFoldResult> offsets,
3860 ArrayRef<OpFoldResult> sizes,
3861 ArrayRef<OpFoldResult> strides,
3862 ArrayRef<NamedAttribute> attrs) {
3863 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3864 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3865 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3866 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3867 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3868 result.addAttributes(attrs);
3869 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3870 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3871 b.getDenseI64ArrayAttr(staticSizes),
3872 b.getDenseI64ArrayAttr(staticStrides));
3873}
3874
3875/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3876/// packed into a Range vector.
3877void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3878 Value source, Value dest,
3879 ArrayRef<Range> ranges,
3880 ArrayRef<NamedAttribute> attrs) {
3881 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3882 build(b, result, source, dest, offsets, sizes, strides, attrs);
3883}
3884
3885// Build a ParallelInsertSliceOp with dynamic entries.
3886void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3887 Value source, Value dest, ValueRange offsets,
3888 ValueRange sizes, ValueRange strides,
3889 ArrayRef<NamedAttribute> attrs) {
3890 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3891 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3892 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3893 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3894 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3895 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3896 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3897}
3898
3899LogicalResult ParallelInsertSliceOp::verify() {
3900 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3901 return this->emitError("expected InParallelOpInterface parent, got:")
3902 << *(getOperation()->getParentOp());
3903
3904 // Verify result type against inferred type.
3905  RankedTensorType expectedType;
3906  SliceVerificationResult result =
3907      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3908                          getStaticSizes(), getStaticStrides(), &expectedType);
3909  if (result != SliceVerificationResult::Success)
3910    return produceSliceErrorMsg(result, *this, expectedType);
3911
3912 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3913 // to the destination tensor.
3914 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3915 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3916 getStaticStrides(), /*generateErrorMessage=*/true);
3917 if (!boundsResult.isValid)
3918 return getOperation()->emitError(boundsResult.errorMessage);
3919
3920 return success();
3921}
3922
3923void ParallelInsertSliceOp::getCanonicalizationPatterns(
3924 RewritePatternSet &results, MLIRContext *context) {
3925 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3926 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3927 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3928}
3929
3930llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3931 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3932}
3933
3934// ParallelCombiningOpInterface implementation.
3935MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3936 return getDestMutable();
3937}
3938
3939Operation *ParallelInsertSliceOp::getIteratingParent() {
3940 // Return the parent InParallelOpInterface's parent.
3941 if (auto combiningOp =
3942 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3943 return combiningOp->getParentOp();
3944 return nullptr;
3945}
3946
3947//===----------------------------------------------------------------------===//
3948// ScatterOp
3949//===----------------------------------------------------------------------===//
3950
3951void ScatterOp::getAsmResultNames(
3952 function_ref<void(Value, StringRef)> setNameFn) {
3953 setNameFn(getResult(), "scatter");
3954}
3955
3956LogicalResult ScatterOp::verify() {
3957 int64_t destRank = getDestType().getRank();
3958 ArrayRef<int64_t> scatterDims = getScatterDims();
3959 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3960 getIndicesType().getShape(), destRank,
3961 "scatter", "dest")))
3962 return failure();
3963
3964 if (!getUnique())
3965 return emitOpError("requires 'unique' attribute to be set");
3966 // TODO: we could also check statically that there are fewer leading index
3967 // tensor dims than the dest dims. If this is not the case, the unique
3968 // attribute cannot be true.
3969
3970 // Use the GatherOp::inferResultType on the `dest` type and verify the
3971 // expected type matches the source type.
3972 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3973 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3974 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3975 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3976 if (getSourceType() != expectedSourceType &&
3977 getSourceType() != expectedRankReducedSourceType) {
3978 return emitOpError("source type "
3979 "mismatch: "
3980 "expected ")
3981 << expectedSourceType << " or its rank-reduced variant "
3982 << expectedRankReducedSourceType << " (got: " << getSourceType()
3983 << ")";
3984 }
3985
3986 return success();
3987}
3988
3989//===----------------------------------------------------------------------===//
3990// SplatOp
3991//===----------------------------------------------------------------------===//
3992
3993void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3994 Type aggregateType, ValueRange dynamicSizes) {
3995 build(builder, result, aggregateType, element, dynamicSizes);
3996}
3997
3998void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3999 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4000 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4001 build(builder, result, aggregateType, element, dynamicSizes);
4002}
4003
4004void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4005 ArrayRef<OpFoldResult> sizes) {
4006 SmallVector<int64_t> staticShape;
4007 SmallVector<Value> dynamicSizes;
4008 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4009 build(builder, result, element, staticShape, dynamicSizes);
4010}
4011
4012void SplatOp::getAsmResultNames(
4013 function_ref<void(Value, StringRef)> setNameFn) {
4014 setNameFn(getResult(), "splat");
4015}
4016
4017LogicalResult SplatOp::verify() {
4018 if (getType().getNumDynamicDims() != getDynamicSizes().size())
4019 return emitOpError("incorrect number of dynamic sizes, has ")
4020 << getDynamicSizes().size() << ", expected "
4021 << getType().getNumDynamicDims();
4022 return success();
4023}
4024
4025LogicalResult
4026SplatOp::reifyResultShapes(OpBuilder &builder,
4027 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4028 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4029 unsigned ctr = 0;
4030 for (int64_t i = 0; i < getType().getRank(); ++i) {
4031 if (getType().isDynamicDim(i)) {
4032 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4033 } else {
4034 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4035 }
4036 }
4037 return success();
4038}
4039
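/// The folder below turns a splat of a constant scalar into a dense splat
/// attribute when the result shape is static. For example (illustrative), with
/// %cst defined by `arith.constant 1.0 : f32`,
///
/// ```mlir
/// %t = tensor.splat %cst : tensor<2x3xf32>
/// ```
///
/// folds to the attribute `dense<1.000000e+00> : tensor<2x3xf32>`.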
4040OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4041 auto constOperand = adaptor.getInput();
4042 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4043 return {};
4044
4045 // Do not fold if the splat is not statically shaped
4046 if (!getType().hasStaticShape())
4047 return {};
4048
4049  // SplatElementsAttr::get treats a single value for the second arg as being a
4050  // splat.
4051 return SplatElementsAttr::get(getType(), {constOperand});
4052}
4053
4054//===----------------------------------------------------------------------===//
4055// Common Canonicalizers and Folders.
4056//===----------------------------------------------------------------------===//
4057static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4058 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4059 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4060 // might need special handling of attached regions.
4061 if (isa<InsertSliceOp>(op.getOperation()) ||
4062 isa<LoopLikeOpInterface>(op.getOperation()))
4063 return false;
4064
4065  return hasFoldableTensorCastOperand(op);
4066}
4067
4068/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4069/// the `tensor.cast` has a source that is more static than the consuming op.
4070///
4071/// Example:
4072/// ```mlir
4073/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4074/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4075/// ```
4076///
4077/// folds into:
4078///
4079/// ```mlir
4080/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4081/// ```
4082/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4083/// can add the pattern to their canonicalizers.
4084struct FoldTensorCastProducerOp
4085    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4086  using OpInterfaceRewritePattern<
4087      DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4088
4089 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4090 PatternRewriter &rewriter) const override {
4091
4092 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4093 // for that instead.
4094 if (!foldTensorCastPrecondition(op) ||
4095 isa<linalg::RelayoutOpInterface>(*op))
4096 return failure();
4097
4098 SmallVector<Type> newResultTypes(op->getResultTypes());
4099 SmallVector<Value> newOperands =
4100 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4101
4102 // Clone op
4103 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4104
4105 SmallVector<Value, 4> replacements;
4106 replacements.reserve(newOp->getNumResults());
4107 for (auto [oldResult, newResult] :
4108 llvm::zip(op->getResults(), newOp->getResults())) {
4109 if (newResult.getType() != oldResult.getType()) {
4110 replacements.push_back(tensor::CastOp::create(
4111 rewriter, op->getLoc(), oldResult.getType(), newResult));
4112 } else {
4113 replacements.push_back(newResult);
4114 }
4115 }
4116 rewriter.replaceOp(op, replacements);
4117
4118 return success();
4119 }
4120};
4121
4122//===----------------------------------------------------------------------===//
4123// TensorDialect
4124//===----------------------------------------------------------------------===//
4125
4126void TensorDialect::getCanonicalizationPatterns(
4127 RewritePatternSet &results) const {
4128 results.add<FoldTensorCastProducerOp>(getContext());
4129}
4130
4131//===----------------------------------------------------------------------===//
4132// TableGen'd op method definitions
4133//===----------------------------------------------------------------------===//
4134
4135#define GET_OP_CLASSES
4136#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
unsigned getNumArguments()
Definition Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:378
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
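A sketch of registering these patterns with a custom control function; the element-count threshold is purely illustrative:
  tensor::populateFoldConstantExtractSlicePatterns(
      patterns, [](tensor::ExtractSliceOp op) {
        // Only fold slices of small, statically shaped constants.
        auto srcType = dyn_cast<RankedTensorType>(op.getSource().getType());
        return srcType && srcType.hasStaticShape() &&
               srcType.getNumElements() <= 16;
      });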
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
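A sketch of how the cast-folding queries above compose inside a rewrite on a destination-style op; op is assumed to be a DestinationStyleOpInterface taken from a matched pattern:
  if (!tensor::hasFoldableTensorCastOperand(op))
    return failure();
  // Swap every foldable producing tensor.cast for its (more static) source and
  // collect the result types the op should have afterwards.
  SmallVector<Type> newResultTypes;
  SmallVector<Value> newOperands =
      tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);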
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:57
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:75
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
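A small worked example under the semantics above: every dimension that is static in source must remain static, with the same extent, in target.
  auto f32 = b.getF32Type();
  int64_t dyn = ShapedType::kDynamic;
  // target refines a dynamic dim of source to 4: static info is preserved.
  bool refines = tensor::preservesStaticInformation(
      RankedTensorType::get({dyn, 8}, f32), RankedTensorType::get({4, 8}, f32));
  // target drops the static extent 4 of source: not preserved.
  bool relaxes = tensor::preservesStaticInformation(
      RankedTensorType::get({4, 8}, f32), RankedTensorType::get({dyn, 8}, f32));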
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
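A minimal sketch (function name hypothetical) that uses these mixed sizes to build an empty tensor with the same shape as an existing ranked tensor value:
  static Value makeEmptyLike(OpBuilder &b, Location loc, Value source) {
    auto type = cast<RankedTensorType>(source.getType());
    // Static dims come back as IndexAttrs, dynamic dims as tensor.dim values.
    SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, source);
    return tensor::EmptyOp::create(b, loc, sizes, type.getElementType());
  }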
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
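A sketch of the typical call site when tiling or decomposing a destination-style op (variable names assumed):
  SmallVector<Value> destinations;
  if (failed(tensor::getOrCreateDestinations(b, op->getLoc(), op, destinations)))
    return failure();
  // destinations now holds one init tensor per tensor result of `op`.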
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
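A small worked example: one static size and one dynamic size rebuilt into the mixed form (the Value d and the builder b are assumed):
  SmallVector<Value> dynamicSizes = {d};
  // Yields {IndexAttr(4), d}: kDynamic entries are replaced, in order, by the
  // next dynamic SSA value.
  SmallVector<OpFoldResult> mixed = getMixedValues(
      {4, ShapedType::kDynamic}, dynamicSizes, b.getContext());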
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
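A sketch of the usual pairing with matchPattern (listed above): bind the constant, then inspect it; the Value index is assumed.
  APInt cst;
  // True only when `index` is produced by a constant and that constant is zero.
  bool isZero = matchPattern(index, m_ConstantInt(&cst)) && cst.isZero();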
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
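A sketch of a common use: deciding whether a mixed stride (an OpFoldResult named stride, assumed) is statically the unit stride.
  std::optional<int64_t> cst = getConstantIntValue(stride);
  bool isUnitStride = cst && *cst == 1; // false for dynamic or non-1 strides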
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
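A small worked example under the semantics above: a 4x8 slice taken at offsets (2, 4) with unit strides fits inside a 10x20 shape.
  SliceBoundsVerificationResult res = verifyInBoundsSlice(
      /*shape=*/{10, 20}, /*staticOffsets=*/{2, 4}, /*staticSizes=*/{4, 8},
      /*staticStrides=*/{1, 1});
  // res.isValid is expected to be true here; errorMessage is only meaningful
  // when generateErrorMessage is set.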
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
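A worked example: with row-major strides {6, 3, 1} (a 2x2x3 shape), the linear index 7 maps to the multi-dimensional offset {1, 0, 1}, since 7 = 1*6 + 0*3 + 1*1.
  SmallVector<int64_t> offsets =
      delinearize(/*linearIndex=*/7, /*strides=*/{6, 3, 1});
  // offsets == {1, 0, 1}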
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:90
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
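A worked example: reducing 4x1x8x1 to 4x8 succeeds, with the dropped unit dimensions at positions 1 and 3; an impossible reduction returns std::nullopt.
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask(/*originalShape=*/{4, 1, 8, 1},
                               /*reducedShape=*/{4, 8});
  // mask contains {1, 3} here.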
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
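A sketch of splitting a mixed size vector back into the static-array-plus-dynamic-operands form many ops store (mixedSizes assumed):
  auto [staticSizes, dynamicSizes] = decomposeMixedValues(mixedSizes);
  // staticSizes holds ShapedType::kDynamic wherever dynamicSizes supplies a Value.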
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.