MLIR 23.0.0git
TensorOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Builders.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/Matchers.h"
33#include "mlir/Support/LLVM.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/Repeated.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/SmallBitVector.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/StringRef.h"
40#include "llvm/Support/Casting.h"
41#include "llvm/Support/MathExtras.h"
42#include <optional>
43
44using namespace mlir;
45using namespace mlir::tensor;
46
47/// Materialize a single constant operation from a given attribute value with
48/// the desired resultant type.
49Operation *TensorDialect::materializeConstant(OpBuilder &builder,
50 Attribute value, Type type,
51 Location loc) {
52 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
53 return op;
54 if (complex::ConstantOp::isBuildableWith(value, type))
55 return complex::ConstantOp::create(builder, loc, type,
56 llvm::cast<ArrayAttr>(value));
57 return nullptr;
58}
59
61 int64_t dim) {
62 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
63 if (tensorType.isDynamicDim(dim))
64 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
65
66 return builder.getIndexAttr(tensorType.getDimSize(dim));
67}
68
70 Location loc, Value value) {
71 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
73 for (int64_t i = 0; i < tensorType.getRank(); ++i)
74 result.push_back(getMixedSize(builder, loc, value, i));
75 return result;
76}
77
79 OpResult opResult) {
80 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
81 assert(tensorType && "expected tensor type");
82
83 // If the op has a destination, it implements DestinationStyleOpInterface and
84 // we can query the destination operand from that interface.
85 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
86 if (destOp)
87 return destOp.getTiedOpOperand(opResult)->get();
88
89 // Otherwise, create a new destination tensor with the same shape.
91 b.setInsertionPoint(opResult.getDefiningOp());
92
93 // Compute sizes.
95 if (!tensorType.hasStaticShape()) {
96 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
97 ReifiedRankedShapedTypeDims reifiedShapes;
98 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
99 return failure();
100 mixedSizes = reifiedShapes[opResult.getResultNumber()];
101 } else {
102 // Static shape: Take static sizes directly.
103 for (int64_t sz : tensorType.getShape())
104 mixedSizes.push_back(b.getIndexAttr(sz));
105 }
106
107 // Create empty tensor with the same encoding as the result type.
108 Attribute encoding;
109 if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType))
110 encoding = rankedTensorType.getEncoding();
111 Value emptyTensor = tensor::EmptyOp::create(
112 b, loc, mixedSizes, tensorType.getElementType(), encoding);
113 return emptyTensor;
114}
115
117 Operation *op,
119 for (OpResult opResult : op->getResults()) {
120 if (llvm::isa<TensorType>(opResult.getType())) {
121 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
122 if (failed(destination))
123 return failure();
124 result.push_back(*destination);
125 }
126 }
127 return success();
128}
129
131 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
132 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
133 return rtp1.getShape() == rtp2.getShape() &&
134 rtp1.getElementType() == rtp2.getElementType();
135 return false;
136 }
137 return tp1 == tp2; // default implementation
138}
139
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
///
/// `reducedShape` is the shape of the rank-reduced tensor; `mixedSizes` are the
/// (static or dynamic) slice sizes of the full-rank op. Returns a bit vector,
/// sized like `mixedSizes`, with a bit set for every dropped dimension.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  // Walk both shapes from the back so that trailing unit dims of `mixedSizes`
  // can be matched greedily against the (shorter) reduced shape.
  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    // `idx` is the forward index of this (reversed) element.
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: Dim is dropped.
    droppedDims.set(idx);
  }

  // Every reduced-shape dim must have been consumed by the walk above.
  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}
181
182/// Given a ranked tensor type and a range of values that defines its dynamic
183/// dimension sizes, turn all dynamic sizes that have a constant value into
184/// static dimension sizes.
185static RankedTensorType
186foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
187 SmallVector<Value> &foldedDynamicSizes) {
188 SmallVector<int64_t> staticShape(type.getShape());
189 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
190 "incorrect number of dynamic sizes");
191
192 // Compute new static and dynamic sizes.
193 unsigned ctr = 0;
194 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
195 if (type.isDynamicDim(i)) {
196 Value dynamicSize = dynamicSizes[ctr++];
197 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
198 if (cst.has_value()) {
199 // Dynamic size must be non-negative.
200 if (cst.value() < 0) {
201 foldedDynamicSizes.push_back(dynamicSize);
202 continue;
203 }
204 staticShape[i] = *cst;
205 } else {
206 foldedDynamicSizes.push_back(dynamicSize);
207 }
208 }
209 }
210
211 return RankedTensorType::get(staticShape, type.getElementType(),
212 type.getEncoding());
213}
214
215//===----------------------------------------------------------------------===//
216// BitcastOp
217//===----------------------------------------------------------------------===//
218
219bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
220 if (inputs.size() != 1 || outputs.size() != 1)
221 return false;
222 Type a = inputs.front(), b = outputs.front();
223 auto aT = dyn_cast<TensorType>(a);
224 auto bT = dyn_cast<TensorType>(b);
225 if (!aT || !bT)
226 return false;
227
228 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
229 return false;
230
231 return succeeded(verifyCompatibleShape(aT, bT));
232}
233
234namespace {
235
236/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
237/// operation.
238struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
239 using OpRewritePattern<BitcastOp>::OpRewritePattern;
240
241 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
242 PatternRewriter &rewriter) const final {
243 auto tensorBitcastOperand =
244 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
245 if (!tensorBitcastOperand)
246 return failure();
247
248 auto resultType = cast<TensorType>(tensorBitcast.getType());
249 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
250 tensorBitcastOperand.getOperand());
251 return success();
252 }
253};
254
255} // namespace
256
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // bitcast(bitcast(x)) -> bitcast(x).
  results.add<ChainedTensorBitcast>(context);
}
261
262//===----------------------------------------------------------------------===//
263// CastOp
264//===----------------------------------------------------------------------===//
265
/// Suggest "cast" as the SSA-name prefix for tensor.cast results.
void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
269
270/// Returns true if `target` is a ranked tensor type that preserves static
271/// information available in the `source` ranked tensor type.
273 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
274 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
275
276 // Requires RankedTensorType.
277 if (!sourceType || !targetType)
278 return false;
279
280 // Requires same elemental type.
281 if (sourceType.getElementType() != targetType.getElementType())
282 return false;
283
284 // Requires same rank.
285 if (sourceType.getRank() != targetType.getRank())
286 return false;
287
288 // Requires same encoding.
289 if (sourceType.getEncoding() != targetType.getEncoding())
290 return false;
291
292 // If cast is towards more static sizes along any dimension, don't fold.
293 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
294 if (ShapedType::isStatic(std::get<0>(t)) &&
295 ShapedType::isDynamic(std::get<1>(t)))
296 return false;
297 }
298
299 return true;
300}
301
302/// Determines whether tensor::CastOp casts to a more dynamic version of the
303/// source tensor. This is useful to fold a tensor.cast into a consuming op and
304/// implement canonicalization patterns for ops in different dialects that may
305/// consume the results of tensor.cast operations. Such foldable tensor.cast
306/// operations are typically inserted as `slice` ops and are canonicalized,
307/// to preserve the type compatibility of their uses.
308///
309/// Returns true when all conditions are met:
310/// 1. source and result are ranked tensors with same element type and rank.
311/// 2. the tensor type has more static information than the result
312///
313/// Example:
314/// ```mlir
315/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
316/// %2 = consumer %1 ... : tensor<?x?xf32> ...
317/// ```
318///
319/// folds into:
320///
321/// ```mlir
322/// %2 = consumer %0 ... : tensor<8x16xf32> ...
323/// ```
325 if (!castOp)
326 return false;
327
328 // Can fold if the source of cast has at least as much static information as
329 // its results.
330 return preservesStaticInformation(castOp.getType(),
331 castOp.getSource().getType());
332}
333
334/// Determines whether the tensor::CastOp casts to a more static version of the
335/// source tensor. This is useful to fold into a producing op and implement
336/// canonicalization patterns with the `tensor.cast` op as the root, but
337/// producer being from different dialects. Returns true when all conditions are
338/// met:
339/// 1. source and result are ranked tensors with same element type and rank.
340/// 2. the result type has more static information than the source.
341///
342/// Example:
343/// ```mlir
344/// %1 = producer ... : tensor<?x?xf32>
345/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
346/// ```
347///
348/// can be canonicalized to :
349///
350/// ```mlir
351/// %2 = producer ... : tensor<8x16xf32>
352/// ```
353/// Not all ops might be canonicalizable this way, but for those that can be,
354/// this method provides a check that it is worth doing the canonicalization.
356 if (!castOp)
357 return false;
358 return preservesStaticInformation(castOp.getSource().getType(),
359 castOp.getType());
360}
361
363 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
364 if (llvm::isa<BlockArgument>(opOperand.get()))
365 return false;
366 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
367 return castOp && canFoldIntoConsumerOp(castOp);
368 });
369}
370
372 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
373 SmallVector<Value> newOperands;
374 newOperands.reserve(op->getNumOperands());
375
376 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
377
378 // Assumes that the result has dpsInits followed by nonDpsInits.
379 int64_t dpsInitIdx = 0;
380 for (OpOperand &opOperand : op->getOpOperands()) {
381 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
382 bool fold = canFoldIntoConsumerOp(tensorCastOp);
383 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
384 if (op.isDpsInit(&opOperand) &&
385 !llvm::isa<MemRefType>(newOperands.back().getType()))
386 newResTy[dpsInitIdx++] = newOperands.back().getType();
387 }
388 return newOperands;
389}
390
391/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
392/// that can be folded.
394 bool folded = false;
395 for (OpOperand &operand : op->getOpOperands()) {
396 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
397 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
398 operand.set(castOp.getOperand());
399 folded = true;
400 }
401 }
402 return success(folded);
403}
404
405bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
406 if (inputs.size() != 1 || outputs.size() != 1)
407 return false;
408 Type a = inputs.front(), b = outputs.front();
409 auto aT = llvm::dyn_cast<TensorType>(a);
410 auto bT = llvm::dyn_cast<TensorType>(b);
411 if (!aT || !bT)
412 return false;
413
414 if (aT.getElementType() != bT.getElementType())
415 return false;
416
417 return succeeded(verifyCompatibleShape(aT, bT));
418}
419
420/// Compute a TensorType that has the joined shape knowledge of the two
421/// given TensorTypes. The element types need to match.
423 assert(one.getElementType() == two.getElementType());
424
425 if (!one.hasRank())
426 return two;
427 if (!two.hasRank())
428 return one;
429
430 int64_t rank = one.getRank();
431 if (rank != two.getRank())
432 return {};
433
435 join.reserve(rank);
436 for (int64_t i = 0; i < rank; ++i) {
437 if (one.isDynamicDim(i)) {
438 join.push_back(two.getDimSize(i));
439 continue;
440 }
441 if (two.isDynamicDim(i)) {
442 join.push_back(one.getDimSize(i));
443 continue;
444 }
445 if (one.getDimSize(i) != two.getDimSize(i))
446 return {};
447 join.push_back(one.getDimSize(i));
448 }
449 return RankedTensorType::get(join, one.getElementType());
450}
451
452namespace {
453
/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    // Only applies when the operand is itself a tensor.cast.
    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just contain
    // less information. If so, we cannot drop the intermediate cast, as doing
    // so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    // Safe to collapse: cast(cast(x)) -> cast(x) from the original source.
    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
492
/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
/// tensor<128x512xf32> to tensor<?x512xf32>
/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
/// tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    // Bail out unless the producer is an extract_slice, the fold is legal, and
    // the cast actually changes the shape.
    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    // Copy the cast's static sizes onto the slice, skipping dims the
    // (possibly rank-reducing) extract_slice dropped.
    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    // `dimIndex` walks the (reduced-rank) result shape in step with `i`.
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
545
546} // namespace
547
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Collapse cast chains and fold casts into extract_slice producers.
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
552
553//===----------------------------------------------------------------------===//
554// ConcatOp
555//===----------------------------------------------------------------------===//
556
/// Infer the concat result type: non-concat dims take the join of the input
/// static sizes; the concat dim is the (saturating) sum of the input sizes.
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  // Join the known static sizes of all inputs along each non-concat dim.
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  // Sum the input sizes along the concat dim; any dynamic input makes the
  // result dynamic (saturated).
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}
582
/// Build a concat op whose result type is inferred from the operand types.
void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}
590
591LogicalResult ConcatOp::verify() {
592 if (getInputs().size() < 1)
593 return emitOpError("requires at least one input");
594
596 for (auto input : getInputs())
597 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
598
599 RankedTensorType resultType = getResultType();
600 int64_t resultRank = getRank();
601 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
602 return type.getRank() != resultRank;
603 }))
604 return emitOpError("rank of concatenated inputs must match result rank");
605
606 Type resultElementType = resultType.getElementType();
607 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
608 return type.getElementType() != resultElementType;
609 }))
610 return emitOpError("inputs and result element type must match");
611
612 int64_t dim = getDim();
613 if (dim >= resultRank)
614 return emitOpError("concatenation dim must be less than the tensor rank");
615
616 SmallVector<int64_t> sizes(resultRank);
617 for (int64_t i = 0, e = resultRank; i < e; ++i) {
618 if (i == dim)
619 continue;
620 SaturatedInteger size;
621 for (auto tensorType : inputTypes) {
622 FailureOr<SaturatedInteger> maybeSize =
623 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
624 if (failed(maybeSize))
625 return emitOpError("static concatenation size mismatch along ")
626 << "non-concatenated dimension " << i;
627 size = *maybeSize;
628 }
629 sizes[i] = size.asInteger();
630 }
631 auto concatSize = SaturatedInteger::wrap(0);
632 for (auto tensorType : inputTypes)
633 concatSize =
634 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
635 sizes[dim] = concatSize.asInteger();
636 auto inferredResultType =
637 RankedTensorType::get(sizes, inputTypes[0].getElementType());
638
639 for (auto [inferredSize, actualSize] :
640 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
641 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
642 ShapedType::isDynamic(actualSize);
643 if (!hasDynamic && inferredSize != actualSize)
644 return emitOpError("result type ")
645 << resultType << "does not match inferred shape "
646 << inferredResultType << " static sizes";
647 }
648
649 return success();
650}
651
652FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
653 size_t numInputs = getInputs().size();
654 uint64_t concatDim = getDim();
655
657 inputShapes.reserve(numInputs);
658 SmallVector<OpFoldResult> concatOffsets;
659 concatOffsets.reserve(numInputs);
660 SmallVector<OpFoldResult> outputShape;
661
662 AffineExpr addExpr =
663 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
664 OpFoldResult zero = builder.getIndexAttr(0);
665 Location loc = getLoc();
666 for (auto [index, input] : llvm::enumerate(getInputs())) {
667 SmallVector<OpFoldResult> inputShape =
668 tensor::getMixedSizes(builder, input.getLoc(), input);
669 if (index == 0) {
670 outputShape = inputShape;
671 concatOffsets.push_back(zero);
672 } else {
673 concatOffsets.push_back(outputShape[concatDim]);
674 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
675 builder, loc, addExpr,
676 {outputShape[concatDim], inputShape[concatDim]});
677 }
678 inputShapes.emplace_back(std::move(inputShape));
679 }
680
681 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
683
684 int64_t rank = getType().getRank();
685 OpFoldResult one = builder.getIndexAttr(1);
686 SmallVector<OpFoldResult> strides(rank, one);
687 SmallVector<OpFoldResult> offsets(rank, zero);
688 for (auto [index, input] : llvm::enumerate(getInputs())) {
689 offsets[concatDim] = concatOffsets[index];
690 auto insertSlice = tensor::InsertSliceOp::create(
691 builder, loc, input, replacement, offsets, inputShapes[index], strides);
692 replacement = insertSlice.getResult();
693 }
694 if (replacement.getType() != getType()) {
695 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
696 }
698}
699
700LogicalResult
701ConcatOp::reifyResultShapes(OpBuilder &builder,
702 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
703 ValueRange inputs = getInputs();
704 int64_t dim = getDim();
705 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
706
707 Value init = inputs[0];
708 int64_t rank = getType().getRank();
709
710 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
711
712 // Pre-populate the result sizes with as much static information as possible
713 // from the given result type, as well as the inferred result type, otherwise
714 // use the dim sizes from the first input.
715 for (int64_t i = 0; i < rank; ++i) {
716 if (i == dim)
717 continue;
718 if (!getType().isDynamicDim(i)) {
719 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
720 } else if (!inferredResultType.isDynamicDim(i)) {
721 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
722 builder, getLoc(),
723 builder.getIndexAttr(inferredResultType.getDimSize(i)));
724 } else {
725 reifiedReturnShapes[0][i] =
726 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
727 }
728 }
729
730 if (getType().isDynamicDim(dim)) {
731 // Take the sum of the input sizes along the concatenated dim.
732 AffineExpr sum = builder.getAffineDimExpr(0);
734 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
735 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
736 sum = sum + builder.getAffineDimExpr(idx + 1);
737 sizes.push_back(
738 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
739 }
740 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
741 builder, getLoc(),
742 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
743 } else {
744 // If the result shape is static along the concatenated dim, use the static
745 // shape.
746 reifiedReturnShapes[0][dim] =
747 builder.getIndexAttr(getType().getDimSize(dim));
748 }
749 return success();
750}
751
/// Suggest "concat" as the SSA-name prefix for tensor.concat results.
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}
756
757OpFoldResult ConcatOp::fold(FoldAdaptor) {
758 ValueRange inputs = getInputs();
759 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
760 return inputs[0];
761 return {};
762}
763
764namespace {
765/// Fold a concat op with a single input to a cast.
766struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
767 using OpRewritePattern<ConcatOp>::OpRewritePattern;
768
769 LogicalResult matchAndRewrite(ConcatOp concatOp,
770 PatternRewriter &rewriter) const override {
771 if (concatOp.getInputs().size() != 1)
772 return failure();
773 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
774 concatOp.getInputs()[0]);
775 return success();
776 }
777};
778
/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except the
/// concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may safely be refined to that same static
/// shape.
///
/// Example:
///
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
/// tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
/// %2 = tensor.concat dim(0) %0, %cast :
/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute inferred type for operand: the shared non-concat dims plus
      // this operand's own size along the concat dim.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use refined operand type and create cast from original operand.
        auto castOp =
            CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                           concatOp.getOperand(operandIdx));
        // In-place update: swap the operand for the refined cast result.
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    // Succeed only if at least one operand was refined.
    return matched;
  }
};
837
838// Ensure `tensor.concat`'s result type is at least as static as can be inferred
839// from its operand types.
840///
841/// Example:
842/// ```mlir
843/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
844/// tensor<?x?xi32>
845/// ```
846/// ->
847/// ```mlir
848/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
849/// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
850/// tensor<?x?xi32>
851/// ```
852struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
853 using OpRewritePattern<ConcatOp>::OpRewritePattern;
854
855 LogicalResult matchAndRewrite(ConcatOp concatOp,
856 PatternRewriter &rewriter) const override {
857 int64_t dim = concatOp.getDim();
858 RankedTensorType inferredResultType =
859 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
860
861 // The result type should be at least as static as inferred result type.
862 if (preservesStaticInformation(inferredResultType,
863 concatOp.getResultType())) {
864 return failure();
865 }
866
867 auto newConcatOp =
868 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
869 concatOp->getOperands());
870 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
871 newConcatOp);
872
873 return success();
874 }
875};
876} // namespace
877
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Single-input simplification plus operand/result type refinement.
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}
884
885//===----------------------------------------------------------------------===//
886// DimOp
887//===----------------------------------------------------------------------===//
888
/// Suggest "dim" as the SSA-name prefix for tensor.dim results.
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}
892
/// Build a tensor.dim from a source tensor and a static index.
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  // Materialize the static index as an arith constant operand.
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}
899
900std::optional<int64_t> DimOp::getConstantIndex() {
902}
903
904Speculation::Speculatability DimOp::getSpeculatability() {
905 auto constantIndex = getConstantIndex();
906 if (!constantIndex)
908
909 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
910 if (!rankedSourceType)
912
913 if (rankedSourceType.getRank() <= constantIndex)
915
917}
918
/// Integer range inference: derive the result range from the shaped-dim
/// interface, using the range of the `index` operand (argRanges[1]).
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}
924
/// Fold tensor.dim: to a constant when the extent is static, to the matching
/// dynamic-extent operand of generate/extract_slice producers, or through a
/// foldable tensor.cast on the source.
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index:
    // dynamic extents are supplied one per dynamic dim, in dim order, so
    // count the dynamic dims preceding `index`.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
    // `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}
985
986namespace {
987/// Fold dim of a cast into the dim of the source of the tensor cast.
988struct DimOfCastOp : public OpRewritePattern<DimOp> {
989 using OpRewritePattern<DimOp>::OpRewritePattern;
990
991 LogicalResult matchAndRewrite(DimOp dimOp,
992 PatternRewriter &rewriter) const override {
993 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
994 if (!castOp)
995 return failure();
996 Value newSource = castOp.getOperand();
997 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
998 return success();
999 }
1000};
1001
1002/// Fold dim of a destination passing style op into the dim of the corresponding
1003/// init.
1004struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1005 using OpRewritePattern<DimOp>::OpRewritePattern;
1006
1007 LogicalResult matchAndRewrite(DimOp dimOp,
1008 PatternRewriter &rewriter) const override {
1009 auto source = dimOp.getSource();
1010 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1011 if (!destOp)
1012 return failure();
1013
1014 auto resultIndex = cast<OpResult>(source).getResultNumber();
1015 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1016
1017 rewriter.modifyOpInPlace(
1018 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1019 return success();
1020 }
1021};
1022
1023/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
1024/// operand.
1025struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1026 using OpRewritePattern<DimOp>::OpRewritePattern;
1027
1028 LogicalResult matchAndRewrite(DimOp dim,
1029 PatternRewriter &rewriter) const override {
1030 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1031
1032 if (!reshape)
1033 return failure();
1034
1035 // Since tensors are immutable we don't need to worry about where to place
1036 // the extract call
1037 rewriter.setInsertionPointAfter(dim);
1038 Location loc = dim.getLoc();
1039 Value extract =
1040 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1041 if (extract.getType() != dim.getType())
1042 extract =
1043 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1044 rewriter.replaceOp(dim, extract);
1045 return success();
1046 }
1047};
1048} // namespace
1049
/// Register the dim canonicalizations defined above.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
1054
1055//===----------------------------------------------------------------------===//
1056// EmptyOp
1057//===----------------------------------------------------------------------===//
1058
1059void EmptyOp::build(OpBuilder &builder, OperationState &result,
1060 ArrayRef<int64_t> staticShape, Type elementType,
1061 Attribute encoding) {
1062 assert(none_of(staticShape, ShapedType::isDynamic) &&
1063 "expected only static sizes");
1064 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1065}
1066
1067void EmptyOp::build(OpBuilder &builder, OperationState &result,
1068 ArrayRef<int64_t> staticShape, Type elementType,
1069 ValueRange dynamicSizes, Attribute encoding) {
1070 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1071 build(builder, result, tensorType, dynamicSizes);
1072}
1073
1074void EmptyOp::build(OpBuilder &builder, OperationState &result,
1075 ArrayRef<OpFoldResult> sizes, Type elementType,
1076 Attribute encoding) {
1077 SmallVector<int64_t> staticShape;
1078 SmallVector<Value> dynamicSizes;
1079 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1080 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1081}
1082
/// The number of dynamic size operands must match the number of dynamic
/// dimensions in the result type.
LogicalResult EmptyOp::verify() {
  return verifyDynamicDimensionCount(getOperation(), getType(),
                                     getDynamicSizes());
}
1087
1088LogicalResult
1089EmptyOp::reifyResultShapes(OpBuilder &builder,
1090 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1091 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1092 unsigned ctr = 0;
1093 for (int64_t i = 0; i < getType().getRank(); ++i) {
1094 if (getType().isDynamicDim(i)) {
1095 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1096 } else {
1097 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1098 }
1099 }
1100 return success();
1101}
1102
1103Value EmptyOp::getDynamicSize(unsigned idx) {
1104 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1105 unsigned ctr = 0;
1106 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1107 if (getType().isDynamicDim(i))
1108 ++ctr;
1109 return getDynamicSizes()[ctr];
1110}
1111
1112SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1113 SmallVector<OpFoldResult> result;
1114 unsigned ctr = 0;
1115 Builder b(getContext());
1116 for (int64_t dim : getType().getShape()) {
1117 if (ShapedType::isDynamic(dim)) {
1118 result.push_back(getDynamicSizes()[ctr++]);
1119 } else {
1120 result.push_back(b.getIndexAttr(dim));
1121 }
1122 }
1123 return result;
1124}
1125
namespace {
/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example
///
///  %c5 = arith.constant 5: index
///  %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
///  to
///
///  %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    // Promote constant dynamic sizes into the result type where possible.
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    // Build the more-static empty op and cast back to the original type so
    // existing users keep type-checking.
    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

/// Fold tensor.dim of a tensor.empty result: when the queried dimension is
/// dynamic, the answer is exactly the corresponding dynamic size operand of
/// the empty op. (Static dims are handled by DimOp::fold.)
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    // Guard against out-of-bounds indices and static dims.
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};

/// Canonicalize
///
/// ```mlir
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducer`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2 : The tensor cast shape is static, but empty tensor result
      // shape is dynamic.
      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3 : The tensor cast shape is dynamic and empty tensor result
      // shape is dynamic. Use the dynamic value from the empty tensor op.
      newMixedSizes.push_back(currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType(),
                                         resultType.getEncoding());
    return success();
  }
};

} // namespace
1249
/// Register the tensor.empty canonicalizations defined above.
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}
1255
1256//===----------------------------------------------------------------------===//
1257// ExtractOp
1258//===----------------------------------------------------------------------===//
1259
namespace {

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    // Only look through casts between ranked tensors; the extract's index
    // count must remain valid for the new source.
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
/// tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] :
/// tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    // Delinearization below needs fully static source extents.
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    // Each collapsed index expands into one index per source dim of its
    // reassociation group.
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      // Singleton groups map 1:1; reuse the index unchanged.
      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      // Delinearize the flat index over the group's static extents.
      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    // Empty reassociation (rank-0 result): every source index is zero.
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};

} // namespace
1349
1350void ExtractOp::getAsmResultNames(
1351 function_ref<void(Value, StringRef)> setNameFn) {
1352 setNameFn(getResult(), "extracted");
1353}
1354
1355LogicalResult ExtractOp::verify() {
1356 // Verify the # indices match if we have a ranked type.
1357 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1358 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1359 return emitOpError("incorrect number of indices for extract_element");
1360 return success();
1361}
1362
1363/// If we have an ExtractOp consuming an InsertOp with the same
1364/// indices, we can return the InsertOp's scalar directly.
1365// TODO: This only checks the immediate producer; extend to go up the
1366// insert/extract chain if the slices are disjoint.
1367static Value foldExtractAfterInsert(ExtractOp extractOp) {
1368 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1369
1370 auto isSame = [](Value a, Value b) {
1372 };
1373 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1374 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1375 return insertOp.getScalar();
1376
1377 return {};
1378}
1379
/// Fold tensor.extract: splat constants, extract-of-from_elements, constant
/// tensors with constant indices, and extract-after-insert at equal indices.
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // If this is a dense resource elements attribute, return.
    // Indexing into the resource blob is not supported here.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor. All indices must be known
  // constants for any of the remaining folds to apply.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    // Linearize the indices in row-major order.
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}
1432
/// Register default extract canonicalizations. ExtractFromCollapseShape is
/// deliberately opt-in via populateFoldCollapseExtractPatterns instead.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}
1437
1439 RewritePatternSet &patterns) {
1440 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1441}
1442
1443//===----------------------------------------------------------------------===//
1444// FromElementsOp
1445//===----------------------------------------------------------------------===//
1446
1447void FromElementsOp::getAsmResultNames(
1448 function_ref<void(Value, StringRef)> setNameFn) {
1449 setNameFn(getResult(), "from_elements");
1450}
1451
1452void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1453 ValueRange elements) {
1454 assert(!elements.empty() && "expected at least one element");
1455 Type resultType = RankedTensorType::get(
1456 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1457 build(builder, result, resultType, elements);
1458}
1459
1460OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1461 // DenseElementsAttr::get requires StringAttr for element types that are not
1462 // integer, index, float, or complex (e.g. vector types), but folded constants
1463 // won't be StringAttr instances. Only fold for element types directly
1464 // supported by DenseElementsAttr.
1465 Type eltType = getType().getElementType();
1466 if (!eltType.isIntOrIndexOrFloat() && !isa<ComplexType>(eltType))
1467 return {};
1468 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1469 return DenseElementsAttr::get(getType(), adaptor.getElements());
1470 return {};
1471}
1472
1473namespace {
1474
1475// Pushes the index_casts that occur before extractions to after the extract.
1476// This minimizes type conversion in some cases and enables the extract
1477// canonicalizer. This changes:
1478//
1479// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1480// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1481//
1482// to the following:
1483//
1484// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1485// %cast = arith.index_cast %extract : i32 to index
1486//
1487// to just %element.
1488//
1489// Consider expanding this to a template and handle all tensor cast
1490// operations.
1491struct ExtractElementFromIndexCast
1492 : public OpRewritePattern<tensor::ExtractOp> {
1493 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1494
1495 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1496 PatternRewriter &rewriter) const final {
1497 Location loc = extract.getLoc();
1498 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1499 if (!indexCast)
1500 return failure();
1501
1502 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1503
1504 auto newExtract = tensor::ExtractOp::create(
1505 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1506
1507 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1508 newExtract);
1509
1510 return success();
1511 }
1512};
1513
1514} // namespace
1515
/// Register the from_elements-related canonicalization defined above.
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}
1520
1521//===----------------------------------------------------------------------===//
1522// GatherOp
1523//===----------------------------------------------------------------------===//
1524
1525void GatherOp::getAsmResultNames(
1526 function_ref<void(Value, StringRef)> setNameFn) {
1527 setNameFn(getResult(), "gather");
1528}
1529
1530/// Return the inferred result type for a gatherOp where:
1531/// - sourceType is the type of the source tensor gathered from
1532/// - indicesType is the type of the indices used to gather
1533/// - gatherDims are the dims along which the gather occurs.
1534/// Return a full rank or ranked-reduced variant of the type depending on
1535/// the value of rankReduced.
1536///
1537/// The leading dimensions of the index tensor give the result tensor its
1538/// leading dimensions.
1539/// The trailing dimensions of the result tensor are obtained from the source
1540/// tensor by setting the dimensions specified in gather_dims to `1` (if
1541/// rankedReduced is false), or skipping them (otherwise).
1542RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1543 RankedTensorType indicesType,
1544 ArrayRef<int64_t> gatherDims,
1545 bool rankReduced) {
1546 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1547 resultShape.reserve(resultShape.size() + sourceType.getRank());
1548 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1549 if (llvm::binary_search(gatherDims, idx)) {
1550 if (!rankReduced)
1551 resultShape.push_back(1);
1552 continue;
1553 }
1554 resultShape.push_back(sourceType.getDimSize(idx));
1555 }
1556 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1557}
1558
1559static LogicalResult
1562 StringRef gatherOrScatter, StringRef sourceOrDest) {
1563 if (dims.empty())
1564 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1565
1566 int64_t numGatherDims = dims.size();
1567 if (numGatherDims > rank)
1568 return op->emitOpError(gatherOrScatter)
1569 << "_dims overflow " << sourceOrDest << " rank";
1570 if (indices.empty() || indices.back() != numGatherDims)
1571 return op->emitOpError(gatherOrScatter)
1572 << "_dims length must match the size of last dimension of indices";
1573 for (int64_t val : dims) {
1574 if (val < 0)
1575 return op->emitOpError(gatherOrScatter)
1576 << "_dims value must be non-negative";
1577 if (val >= rank)
1578 return op->emitOpError(gatherOrScatter)
1579 << "_dims value must be smaller than " << sourceOrDest << " rank";
1580 }
1581 for (int64_t i = 1; i < numGatherDims; ++i) {
1582 if (dims[i - 1] >= dims[i])
1583 return op->emitOpError(gatherOrScatter)
1584 << "_dims values must be strictly increasing";
1585 }
1586 return success();
1587}
1588
/// Verify gather_dims and check the result type against both the full-rank
/// and rank-reduced inferred types; either form is accepted.
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       "mismatch: "
                       "expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}
1613
1614OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1615 if (OpFoldResult reshapedSource = reshapeConstantSource(
1616 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1617 getResult().getType()))
1618 return reshapedSource;
1619 return {};
1620}
1621
1622//===----------------------------------------------------------------------===//
1623// InsertOp
1624//===----------------------------------------------------------------------===//
1625
1626void InsertOp::getAsmResultNames(
1627 function_ref<void(Value, StringRef)> setNameFn) {
1628 setNameFn(getResult(), "inserted");
1629}
1630
1631LogicalResult InsertOp::verify() {
1632 // Verify the # indices match if we have a ranked type.
1633 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1634 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1635 return emitOpError("incorrect number of indices");
1636 return success();
1637}
1638
1639OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1640 Attribute scalar = adaptor.getScalar();
1641 Attribute dest = adaptor.getDest();
1642 if (scalar && dest)
1643 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1644 if (scalar == splatDest.getSplatValue<Attribute>())
1645 return dest;
1646 return {};
1647}
1648
1649//===----------------------------------------------------------------------===//
1650// GenerateOp
1651//===----------------------------------------------------------------------===//
1652
1653void GenerateOp::getAsmResultNames(
1654 function_ref<void(Value, StringRef)> setNameFn) {
1655 setNameFn(getResult(), "generated");
1656}
1657
1658LogicalResult GenerateOp::reifyResultShapes(
1659 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1660 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1661 int idx = 0;
1662 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1663 if (getType().isDynamicDim(dim)) {
1664 reifiedReturnShapes[0][dim] = getOperand(idx++);
1665 } else {
1666 reifiedReturnShapes[0][dim] =
1667 builder.getIndexAttr(getType().getDimSize(dim));
1668 }
1669 }
1670 return success();
1671}
1672
1673LogicalResult GenerateOp::verify() {
1674 // Ensure that the tensor type has as many dynamic dimensions as are
1675 // specified by the operands.
1676 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1677 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
1678 getOperands())))
1679 return failure();
1680 return success();
1681}
1682
1683LogicalResult GenerateOp::verifyRegions() {
1684 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1685 // Ensure that region arguments span the index space.
1686 if (!llvm::all_of(getBody().getArgumentTypes(),
1687 [](Type ty) { return ty.isIndex(); }))
1688 return emitError("all body arguments must be index");
1689 if (getBody().getNumArguments() != resultTy.getRank())
1690 return emitError("must have one body argument per input dimension");
1691
1692 // Ensure that the region yields an element of the right type.
1693 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1694
1695 if (yieldOp.getValue().getType() != resultTy.getElementType())
1696 return emitOpError(
1697 "body must be terminated with a `yield` operation of the tensor "
1698 "element type");
1699
1700 return success();
1701}
1702
/// Build a tensor.generate and populate its body via `bodyBuilder`, which is
/// invoked with one index block argument per result dimension.
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  // The guard restores the caller's insertion point after createBlock moves
  // it into the new block.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  // One index argument per dimension, all located at the op itself.
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
1719
namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    // Promote constant dynamic extents into the result type where possible.
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    // Move the original body over; it is unaffected by the type change.
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    // Cast back to the original type so existing users keep type-checking.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    // Clone the body with the index block arguments replaced by the
    // extract's concrete indices.
    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    // The yielded value (remapped into the clone) replaces the extract.
    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};

} // namespace
1785
/// Register the tensor.generate canonicalizations defined above.
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
1791
1792//===----------------------------------------------------------------------===//
1793// RankOp
1794//===----------------------------------------------------------------------===//
1795
1796void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1797 setNameFn(getResult(), "rank");
1798}
1799
1800OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1801 // Constant fold rank when the rank of the operand is known.
1802 auto type = getOperand().getType();
1803 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1804 if (shapedType && shapedType.hasRank())
1805 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1806 return IntegerAttr();
1807}
1808
1809//===----------------------------------------------------------------------===//
1810// ReshapeOp
1811//===----------------------------------------------------------------------===//
1812
1813void ReshapeOp::getAsmResultNames(
1814 function_ref<void(Value, StringRef)> setNameFn) {
1815 setNameFn(getResult(), "reshape");
1816}
1817
1818static int64_t getNumElements(ShapedType type) {
1819 int64_t numElements = 1;
1820 for (auto dim : type.getShape())
1821 numElements *= dim;
1822 return numElements;
1823}
1824
LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  // A reshape never changes the element type.
  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  // Length of the 1-D shape operand; kDynamic when its extent is dynamic.
  int64_t shapeSize =
      llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    // When both shapes are fully static, the element counts must agree.
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    // A ranked result additionally requires a statically sized shape operand
    // whose length matches the result rank.
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}
1854
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  // Fold a reshape of a constant into a constant of the result type.
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
  // producer's input instead as the original tensor to reshape. This could
  // render such producer dead code.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  // The remaining folds require identical ranked source/result types.
  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both 0D or 1D tensors and have the same type,
  // the reshape has no effect, even if the tensor is dynamically shaped.
  if (sourceTy.getRank() <= 1)
    return source;

  // For rank >= 2 with equal (possibly dynamic) types, the reshape is a no-op
  // only if the shape operand provably reproduces the source's shape.
  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      // Constant shape element: must equal the (static) source dim size.
      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      // Dynamic shape element: must be tensor.dim(source, id).
      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      // Any other shape element cannot be proven to match the source shape.
      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
1911
1912//===----------------------------------------------------------------------===//
1913// Reassociative reshape ops
1914//===----------------------------------------------------------------------===//
1915
1916void CollapseShapeOp::getAsmResultNames(
1917 function_ref<void(Value, StringRef)> setNameFn) {
1918 setNameFn(getResult(), "collapsed");
1919}
1920
1921void ExpandShapeOp::getAsmResultNames(
1922 function_ref<void(Value, StringRef)> setNameFn) {
1923 setNameFn(getResult(), "expanded");
1924}
1925
1926int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1927 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1928 "invalid resultDim");
1929 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1930 if (llvm::is_contained(it.value(), resultDim))
1931 return it.index();
1932 llvm_unreachable("could not find reassociation group");
1933}
1934
1935FailureOr<SmallVector<OpFoldResult>>
1936ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1937 RankedTensorType expandedType,
1938 ArrayRef<ReassociationIndices> reassociation,
1939 ArrayRef<OpFoldResult> inputShape) {
1940 std::optional<SmallVector<OpFoldResult>> outputShape =
1941 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1942 inputShape);
1943 if (!outputShape)
1944 return failure();
1945 return *outputShape;
1946}
1947
1948SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1949 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1950}
1951
1952void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1953 Type resultType, Value src,
1954 ArrayRef<ReassociationIndices> reassociation,
1955 ArrayRef<OpFoldResult> outputShape) {
1956 auto [staticOutputShape, dynamicOutputShape] =
1957 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1958 build(builder, result, cast<RankedTensorType>(resultType), src,
1959 getReassociationIndicesAttribute(builder, reassociation),
1960 dynamicOutputShape, staticOutputShape);
1961}
1962
1963void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1964 Type resultType, Value src,
1965 ArrayRef<ReassociationIndices> reassociation) {
1966 SmallVector<OpFoldResult> inputShape =
1967 getMixedSizes(builder, result.location, src);
1968 auto tensorResultTy = cast<RankedTensorType>(resultType);
1969 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1970 builder, result.location, tensorResultTy, reassociation, inputShape);
1971 SmallVector<OpFoldResult> outputShapeOrEmpty;
1972 if (succeeded(outputShape)) {
1973 outputShapeOrEmpty = *outputShape;
1974 }
1975 build(builder, result, tensorResultTy, src, reassociation,
1976 outputShapeOrEmpty);
1977}
1978
1979SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1980 return getSymbolLessAffineMaps(getReassociationExprs());
1981}
1982SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1984 getReassociationIndices());
1985}
1986
1987SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1988 return getSymbolLessAffineMaps(getReassociationExprs());
1989}
1990SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1992 getReassociationIndices());
1993}
1994
1995RankedTensorType CollapseShapeOp::inferCollapsedType(
1996 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1997 return inferCollapsedType(
1999 type.getContext(), reassociation)));
2000}
2001
2002/// Compute the RankedTensorType obtained by applying `reassociation` to
2003/// `type`.
2004RankedTensorType
2005CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2006 ArrayRef<AffineMap> reassociation) {
2007 auto shape = type.getShape();
2008 SmallVector<int64_t, 4> newShape;
2009 newShape.reserve(reassociation.size());
2010
2011 // Use the fact that reassociation is valid to simplify the logic: only use
2012 // each map's rank.
2013 assert(isReassociationValid(reassociation) && "invalid reassociation");
2014 unsigned currentDim = 0;
2015 for (AffineMap m : reassociation) {
2016 unsigned dim = m.getNumResults();
2017 auto band = shape.slice(currentDim, dim);
2018 int64_t size = 1;
2019 if (llvm::is_contained(band, ShapedType::kDynamic))
2020 size = ShapedType::kDynamic;
2021 else
2022 for (unsigned d = 0; d < dim; ++d)
2023 size *= shape[currentDim + d];
2024 newShape.push_back(size);
2025 currentDim += dim;
2026 }
2027
2028 return RankedTensorType::get(newShape, type.getElementType());
2029}
2030
2031void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2032 ArrayRef<ReassociationIndices> reassociation,
2033 ArrayRef<NamedAttribute> attrs) {
2034 auto srcType = llvm::cast<RankedTensorType>(src.getType());
2035 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2036 auto resultType =
2037 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2038 srcType.getEncoding());
2039 result.addAttribute(getReassociationAttrStrName(),
2040 getReassociationIndicesAttribute(b, reassociation));
2041 build(b, result, resultType, src, attrs);
2042}
2043
2044template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2045 TensorReshapeOp, ExpandShapeOp>::value>
2046static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2047 RankedTensorType expandedType,
2048 RankedTensorType collapsedType) {
2049 if (failed(
2050 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2051 return failure();
2052
2053 // Reshape must preserve the number of elements when statically known.
2054 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
2055 int64_t expandedNumElements = expandedType.getNumElements();
2056 int64_t collapsedNumElements = collapsedType.getNumElements();
2057 if (expandedNumElements != collapsedNumElements) {
2058 return op.emitOpError("number of elements must be preserved: ")
2059 << expandedNumElements << " != " << collapsedNumElements;
2060 }
2061 }
2062
2063 auto maps = op.getReassociationMaps();
2064 RankedTensorType expectedType =
2065 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2066 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2067 return op.emitOpError("expected collapsed type to be ")
2068 << expectedType << ", but got " << collapsedType;
2069 return success();
2070}
2071
LogicalResult ExpandShapeOp::verify() {
  RankedTensorType srcType = getSrc().getType();
  RankedTensorType resultType = getResult().getType();

  // static_output_shape must describe every result dimension.
  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  // Each kDynamic entry in static_output_shape must be backed by exactly one
  // SSA value in output_shape.
  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that the number of dynamic dims in output_shape matches the number
  // of dynamic dims in the result type.
  if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
                                         getOutputShape())))
    return failure();

  // Verify if provided output shapes are in agreement with output type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape))
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos])
      return emitOpError("invalid output shape provided at pos ") << pos;

  // For an expand, the result is the expanded type and the source the
  // collapsed one — hence the (result, src) argument order.
  return verifyTensorReshapeOp(*this, resultType, srcType);
}
2105
2106LogicalResult CollapseShapeOp::verify() {
2107 CollapseShapeOp op = *this;
2108 if (llvm::any_of(op.getReassociationIndices(),
2109 [](ReassociationIndices group) { return group.empty(); })) {
2110 return op.emitOpError("reassociation indices must not be empty");
2111 }
2112 RankedTensorType srcType = op.getSrc().getType();
2113 RankedTensorType resultType = op.getResult().getType();
2114
2115 return verifyTensorReshapeOp(op, srcType, resultType);
2116}
2117
2118namespace {
2119/// Reshape of a splat constant can be replaced with a constant of the result
2120/// type.
2121template <typename TensorReshapeOp>
2122struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2123 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2124 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2125 PatternRewriter &rewriter) const override {
2126 DenseElementsAttr attr;
2127 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2128 return failure();
2129 if (!attr || !attr.isSplat())
2130 return failure();
2131 // DenseElementsAttr requires a static shape; skip folding for dynamic
2132 // result types.
2133 if (!reshapeOp.getResultType().hasStaticShape())
2134 return failure();
2135 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2136 reshapeOp.getResultType(), attr.getRawData());
2137 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2138 return success();
2139 }
2140};
2141
2142// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2143template <typename TensorReshapeOp>
2144class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2145public:
2146 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2147
2148 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2149 PatternRewriter &rewriter) const override {
2150 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2151 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2152 return failure();
2153
2154 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2155 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2156 return success();
2157 }
2158};
2159
2160/// Reshape of a FromElements can be replaced with a FromElements of the
2161/// result type
2162template <typename TensorReshapeOp>
2163struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2164 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2165 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2166 PatternRewriter &rewriter) const override {
2167 auto fromElements =
2168 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2169 if (!fromElements)
2170 return failure();
2171
2172 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2173
2174 if (!shapedTy.hasStaticShape())
2175 return failure();
2176
2177 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2178 fromElements.getElements());
2179 return success();
2180 }
2181};
2182
// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    // Only fold casts that refine the operand type (canFoldIntoConsumerOp also
    // rejects a null castOp).
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    // Recompute the collapsed type from the cast's (more static) source type.
    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      // Same result type: simply bypass the cast in place.
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      // More static result type: collapse the cast's source, then cast the
      // result back to the original (less static) type for existing users.
      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};
2212
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    // Only fold casts that refine the operand type (canFoldIntoConsumerOp also
    // rejects a null castOp).
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    // Start from the current result shape and refine dynamic extents whenever
    // the corresponding output_shape operand is a constant. `outputIt` walks
    // the dynamic output_shape operands in lock-step with the dynamic dims.
    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        // Static result dims have no output_shape operand; skip them.
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires at
        // least one of the expanded dimensions to be dynamic if the input is
        // dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        // Any dynamic dim in a group makes the collapsed input dim dynamic.
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    // Rebuild: cast the source to the refined input type, expand with the
    // refined output shape, then cast back to the original result type.
    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
2292} // namespace
2293
2294void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2295 MLIRContext *context) {
2296 results.add<
2297 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2298 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2299 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2300 FoldReshapeWithSplat<ExpandShapeOp>,
2301 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2302}
2303
2304void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2305 MLIRContext *context) {
2306 results.add<
2307 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2308 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2309 tensor::DimOp, RankedTensorType>,
2310 FoldReshapeWithConstant<CollapseShapeOp>,
2311 FoldReshapeWithSplat<CollapseShapeOp>,
2312 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2313 context);
2314}
2315
2316OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2318 adaptor.getOperands());
2319}
2320
2321OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2323 adaptor.getOperands());
2324}
2325
2326//===----------------------------------------------------------------------===//
2327// ExtractSliceOp
2328//===----------------------------------------------------------------------===//
2329
2330void ExtractSliceOp::getAsmResultNames(
2331 function_ref<void(Value, StringRef)> setNameFn) {
2332 setNameFn(getResult(), "extracted_slice");
2333}
2334
2335/// An extract_slice result type can be inferred, when it is not
2336/// rank-reduced, from the source type and the static representation of
2337/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2338RankedTensorType
2339ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2340 ArrayRef<int64_t> staticSizes) {
2341 // An extract_slice op may specify only a leading subset of offset/sizes/
2342 // strides in which case we complete with offset=0, sizes from memref type
2343 // and strides=1.
2344 assert(static_cast<int64_t>(staticSizes.size()) ==
2345 sourceTensorType.getRank() &&
2346 "unexpected staticSizes not equal to rank of source");
2347 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2348 sourceTensorType.getEncoding());
2349}
2350
2351// TODO: This uses neither offsets nor strides!
2352RankedTensorType
2353ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2354 ArrayRef<OpFoldResult> sizes) {
2355 SmallVector<int64_t> staticSizes;
2356 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2357
2358 assert(static_cast<int64_t>(staticSizes.size()) ==
2359 sourceTensorType.getRank() &&
2360 "unexpected staticSizes not equal to rank of source");
2361 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2362 sourceTensorType.getEncoding());
2363}
2364
2365/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2366/// number of sizes), drop as many size 1 as needed to produce an inferred
2367/// type with the desired rank.
2368///
2369/// Note that there may be multiple ways to compute this rank-reduced type:
2370/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2371///
2372/// To disambiguate, this function always drops the first 1 sizes occurrences.
2373RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2374 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2375 ArrayRef<int64_t> sizes) {
2376 // Type inferred in the absence of rank-reducing behavior.
2377 auto inferredType = llvm::cast<RankedTensorType>(
2378 inferResultType(sourceRankedTensorType, sizes));
2379 int rankDiff = inferredType.getRank() - desiredResultRank;
2380 if (rankDiff > 0) {
2381 auto shape = inferredType.getShape();
2382 llvm::SmallBitVector dimsToProject =
2383 getPositionsOfShapeOne(rankDiff, shape);
2384 SmallVector<int64_t> projectedShape;
2385 // Best effort rank-reducing: drop 1s in order.
2386 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2387 if (!dimsToProject.test(pos))
2388 projectedShape.push_back(shape[pos]);
2389 inferredType =
2390 RankedTensorType::get(projectedShape, inferredType.getElementType());
2391 }
2392 return inferredType;
2393}
2394
2395RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2396 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2397 ArrayRef<OpFoldResult> sizes) {
2398 SmallVector<int64_t> staticSizes;
2399 SmallVector<Value> dynamicSizes;
2400 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2401 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2402 desiredResultRank, sourceRankedTensorType, staticSizes);
2403}
2404
2405/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2406/// result type. If the type passed is nullptr, it is inferred.
2407void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2408 RankedTensorType resultType, Value source,
2409 ArrayRef<OpFoldResult> offsets,
2410 ArrayRef<OpFoldResult> sizes,
2411 ArrayRef<OpFoldResult> strides,
2412 ArrayRef<NamedAttribute> attrs) {
2413 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2414 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2415 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2416 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2417 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2418 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2419 // Structuring implementation this way avoids duplication between builders.
2420 if (!resultType) {
2421 resultType = llvm::cast<RankedTensorType>(
2422 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2423 }
2424 result.addAttributes(attrs);
2425 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2426 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2427 b.getDenseI64ArrayAttr(staticSizes),
2428 b.getDenseI64ArrayAttr(staticStrides));
2429}
2430
2431/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2432/// result type.
2433void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2434 ArrayRef<OpFoldResult> offsets,
2435 ArrayRef<OpFoldResult> sizes,
2436 ArrayRef<OpFoldResult> strides,
2437 ArrayRef<NamedAttribute> attrs) {
2438 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2439}
2440
2441/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2442/// a Range vector.
2443void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2444 ArrayRef<Range> ranges,
2445 ArrayRef<NamedAttribute> attrs) {
2446 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2447 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2448}
2449
2450/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2451/// the type passed is nullptr, it is inferred.
2452void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2453 RankedTensorType resultType, Value source,
2454 ValueRange offsets, ValueRange sizes,
2455 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2456 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2457 offsets, [](Value v) -> OpFoldResult { return v; });
2458 SmallVector<OpFoldResult> sizeValues =
2459 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2460 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2461 strides, [](Value v) -> OpFoldResult { return v; });
2462 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2463}
2464
2465/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2466void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2467 ValueRange offsets, ValueRange sizes,
2468 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2469 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2470}
2471
2473 Operation *op,
2474 RankedTensorType expectedType) {
2475 switch (result) {
2477 return success();
2479 return op->emitError("expected rank to be smaller or equal to ")
2480 << "the other rank. ";
2482 return op->emitError("expected type to be ")
2483 << expectedType << " or a rank-reduced version. (size mismatch) ";
2485 return op->emitError("expected element type to be ")
2486 << expectedType.getElementType();
2487 default:
2488 llvm_unreachable("unexpected extract_slice op verification result");
2489 }
2490}
2491
2492/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2493/// result type, offsets set to 0 and strides set to 1.
2494void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2495 RankedTensorType resultType, Value source,
2496 ArrayRef<OpFoldResult> sizes,
2497 ArrayRef<NamedAttribute> attrs) {
2498 Attribute zeroIdxAttr = b.getIndexAttr(0);
2499 Attribute oneIdxAttr = b.getIndexAttr(1);
2500 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2501 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2502 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2503}
2504
2505/// Verifier for ExtractSliceOp.
2506LogicalResult ExtractSliceOp::verify() {
2507 RankedTensorType sourceType = getSourceType();
2508
2509 // Verify result type against inferred type.
2510 RankedTensorType expectedType =
2511 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2514 return produceSliceErrorMsg(result, *this, expectedType);
2515
2516 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2517 // to the source tensor.
2518 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2519 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2520 getStaticStrides(), /*generateErrorMessage=*/true);
2521 if (!boundsResult.isValid)
2522 return getOperation()->emitError(boundsResult.errorMessage);
2523
2524 return success();
2525}
2526
2527llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2528 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2529}
2530
2531FailureOr<Value>
2532ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2533 ArrayRef<int64_t> desiredShape) {
2534 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2535 assert(sourceTensorType && "not a ranked tensor type");
2536 auto sourceShape = sourceTensorType.getShape();
2537 if (sourceShape.equals(desiredShape))
2538 return value;
2539 auto maybeRankReductionMask =
2540 mlir::computeRankReductionMask(sourceShape, desiredShape);
2541 if (!maybeRankReductionMask)
2542 return failure();
2544 b, loc, value,
2545 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2546}
2547
2548LogicalResult ExtractSliceOp::reifyResultShapes(
2549 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2550 reifiedReturnShapes.resize(1);
2551 reifiedReturnShapes[0].reserve(getType().getRank());
2552 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2553 llvm::SmallBitVector droppedDims = getDroppedDims();
2554 for (const auto &size : enumerate(mixedSizes)) {
2555 if (droppedDims.test(size.index()))
2556 continue;
2557 reifiedReturnShapes[0].push_back(size.value());
2558 }
2559 return success();
2560}
2561
2562namespace {
2563/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2564/// This essentially pushes memref_cast past its consuming slice when
2565/// `canFoldIntoConsumerOp` is true.
2566///
2567/// Example:
2568/// ```
2569/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2570/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2571/// tensor<3x4xf32>
2572/// ```
2573/// is rewritten into:
2574/// ```
2575/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2576/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2577/// ```
2578class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2579public:
2580 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2581
2582 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2583 PatternRewriter &rewriter) const override {
2584 // Any constant operand, just return to let the constant folder kick in.
2585 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2586 return matchPattern(operand, matchConstantIndex());
2587 }))
2588 return failure();
2589
2590 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2591 if (!castOp)
2592 return failure();
2593
2594 if (!canFoldIntoConsumerOp(castOp))
2595 return failure();
2596
2597 // Pattern does not apply if the produced op would not verify.
2598 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2599 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2600 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2601 sliceOp.getStaticStrides());
2602 if (!sliceResult.isValid)
2603 return failure();
2604
2605 // Create folded extract.
2606 Location loc = sliceOp.getLoc();
2607 Value newResult = ExtractSliceOp::create(
2608 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2609 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2610 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2611 sliceOp.getStaticStrides());
2612 rewriter.replaceOp(sliceOp, newResult);
2613 return success();
2614 }
2615};
2616
2617/// Slice elements from `values` into `outValues`. `counts` represents the
2618/// numbers of elements to stride in the original values for each dimension.
2619/// The output values can be used to construct a DenseElementsAttr.
2620template <typename IterTy, typename ElemTy>
2621static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2622 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2623 ArrayRef<int64_t> strides,
2624 llvm::SmallVectorImpl<ElemTy> *outValues) {
2625 assert(offsets.size() == sizes.size());
2626 assert(offsets.size() == strides.size());
2627 if (offsets.empty())
2628 return;
2629
2630 int64_t offset = offsets.front();
2631 int64_t size = sizes.front();
2632 int64_t stride = strides.front();
2633 if (offsets.size() == 1) {
2634 for (int64_t i = 0; i < size; ++i, offset += stride)
2635 outValues->push_back(*(values + offset));
2636
2637 return;
2638 }
2639
2640 for (int64_t i = 0; i < size; ++i, offset += stride) {
2641 auto begin = values + offset * counts.front();
2642 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2643 offsets.drop_front(), sizes.drop_front(),
2644 strides.drop_front(), outValues);
2645 }
2646}
2647
/// Fold arith.constant and tensor.extract_slice into arith.constant. The
/// folded operation might introduce more constant data; Users can control
/// their heuristics by the control function.
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the source is a non-splat dense constant.
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();

    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shape is not supported.
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension: counts[i] is the number of
    // elements spanned by one step in dimension i (row-major suffix product
    // of the source shape).
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // Slice the elements and construct a new attribute.
    SmallVector<Attribute> outValues;
    outValues.reserve(resultType.getNumElements());
    sliceElements(attr.value_begin<Attribute>(), counts, offsets, sizes,
                  strides, &outValues);
    auto newAttr = DenseElementsAttr::get(resultType, outValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
    return success();
  }

private:
  /// This additionally controls whether the fold happens or not. Users can
  /// impose their heuristics in the function.
};
2720
2721} // namespace
2722
    RewritePatternSet &patterns,
    const ControlConstantExtractSliceFusionFn &controlFn) {
  // Register the arith.constant + extract_slice folder; `controlFn` gates
  // when the (potentially data-duplicating) fold is allowed to fire.
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
}
2728
/// Return the canonical type of the result of an extract_slice op.
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a tensor type without taking into account any rank reductions.
    RankedTensorType nonReducedType =
        ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);

    // Directly return the non-rank reduced type if there are no dropped
    // dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Build the reduced shape, preserving the original rank reduction pattern.
    SmallVector<int64_t> targetShape;
    for (auto i : llvm::seq<int64_t>(mixedSizes.size()))
      if (!droppedDims.test(i))
        targetShape.push_back(nonReducedType.getDimSize(i));

    // Keep the element type and encoding of the non-reduced inference.
    return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
                                 nonReducedType.getEncoding());
  }
};
2755
/// A canonicalizer wrapper to replace ExtractSliceOps.
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    // If canonicalization changed the (rank-reduced) result type, insert a
    // cast so existing users keep seeing the original type.
    if (replacement.getType() != op.getType())
      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
                                           replacement);
    rewriter.replaceOp(op, replacement);
  }
};
2767
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Canonicalize constant offsets/sizes/strides (re-inferring the
  // rank-reduced result type) and absorb producing tensor.cast ops.
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}
2775
2776//
2777static LogicalResult
2778foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2779 ShapedType shapedType) {
2780 OpBuilder b(op.getContext());
2781 for (OpFoldResult ofr : op.getMixedOffsets())
2782 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2783 return failure();
2784 // Rank-reducing noops only need to inspect the leading dimensions:
2785 // llvm::zip is appropriate.
2786 auto shape = shapedType.getShape();
2787 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2788 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2789 return failure();
2790 for (OpFoldResult ofr : op.getMixedStrides())
2791 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2792 return failure();
2793 return success();
2794}
2795
2796/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2797/// slice, we can return the InsertSliceOp's source directly.
2798// TODO: This only checks the immediate producer; extend to go up the
2799// insert/extract chain if the slices are disjoint.
2800static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2801 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2802
2803 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2804 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2805 insertOp.isSameAs(extractOp, isSame))
2806 return insertOp.getSource();
2807
2808 return {};
2809}
2810
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  // Extracting from a constant splat folds to a splat of the result type.
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  // A same-typed identity slice folds to its source.
  if (getSourceType() == getType() &&
    return this->getSource();
  // extract_slice(insert_slice) with identical slice parameters folds to the
  // inserted value.
  if (Value slice = foldExtractAfterInsertSlice(*this))
    return slice;

  return OpFoldResult();
}
2824
2826 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2827 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2828 unsigned rank = rankedTensorType.getRank();
2829 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2831 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2832 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2833 offsets, sizes, strides);
2834}
2835
2836//===----------------------------------------------------------------------===//
2837// InsertSliceOp
2838//===----------------------------------------------------------------------===//
2839
/// Suggest "inserted_slice" as the SSA name prefix for the op's result.
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}
2844
// Build a InsertSliceOp with mixed static and dynamic entries. Each
// OpFoldResult is split into a static integer (or kDynamic) plus, when
// dynamic, a Value operand.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> strides,
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  // The result type is always the destination type.
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
2862
2863/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2864/// Range vector.
2865void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2866 Value dest, ArrayRef<Range> ranges,
2867 ArrayRef<NamedAttribute> attrs) {
2868 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2869 build(b, result, source, dest, offsets, sizes, strides, attrs);
2870}
2871
2872// Build a InsertSliceOp with dynamic entries.
2873void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2874 Value dest, ValueRange offsets, ValueRange sizes,
2875 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2876 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2877 offsets, [](Value v) -> OpFoldResult { return v; });
2878 SmallVector<OpFoldResult> sizeValues =
2879 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2880 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2881 strides, [](Value v) -> OpFoldResult { return v; });
2882 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2883}
2884
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp. Infers the expected (non-rank-reduced) source type
/// from `dstType` and `staticSizes`, optionally returns it via
/// `expectedType`, and checks `srcType` is a legal rank reduction of it.
    RankedTensorType srcType, RankedTensorType dstType,
    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
    ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice, use the same type
  // inference.
  RankedTensorType expected =
      ExtractSliceOp::inferResultType(dstType, staticSizes);
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}
2899
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
  // Verify result type against inferred type.
  RankedTensorType expectedType;
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
2920
2921/// If we have two consecutive InsertSliceOp writing to the same slice, we
2922/// can mutate the second InsertSliceOp's destination to the first one's.
2923///
2924/// Example:
2925///
2926/// ```mlir
2927/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2928/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2929/// ```
2930///
2931/// folds into:
2932///
2933/// ```mlir
2934/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2935/// ```
2936///
2937/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2938static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2939 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2940
2941 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2942 if (!prevInsertOp ||
2943 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2944 !prevInsertOp.isSameAs(insertOp, isSame))
2945 return failure();
2946
2947 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2948 return success();
2949}
2950
2951/// Folds round-trip extract/insert slice op pairs.
2952/// Example:
2953/// ```mlir
2954/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2955/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2956/// ```
2957/// can be folded into %val.
2958static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2959 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2960
2961 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2962 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2963 !extractOp.isSameAs(insertOp, isSame))
2964 return nullptr;
2965
2966 return extractOp.getSource();
2967}
2968
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  // A static, same-typed identity insert folds to its source.
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
    return this->getSource();
  // Two consecutive inserts of the same slice: drop the earlier one.
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  // insert_slice(extract_slice) round trip folds to the original tensor.
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;
  // Inserting an empty slice leaves the destination unchanged.
  if (llvm::any_of(getMixedSizes(), isZeroInteger))
    return getDest();
  return OpFoldResult();
}
2982
2983LogicalResult InsertSliceOp::reifyResultShapes(
2984 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2985 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2986 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2987 return success();
2988}
2989
2990namespace {
/// Pattern to rewrite a insert_slice op with constant arguments: dynamic
/// offsets/sizes/strides that are known constants are folded into the static
/// attribute arrays.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return;
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
                            mixedOffsets, mixedSizes, mixedStrides);
    if (!sliceResult.isValid)
      return failure();

    // Create the new op in canonical form. The source may need a cast to the
    // canonical rank-reduced type implied by the folded sizes.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedSizes);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the InParallelOp in
      // the parallel case.
      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
3040
/// Fold tensor_casts with insert_slice operations. If the source or
/// destination tensor is a tensor_cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result did
/// not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand: let the constant folder handle it first.
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Return the cast's source when `v` is a foldable tensor.cast.
    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not seen
    // in the insert slice op static sizes, so we ignore dynamic dims when
    // computing the rank reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();
    // Replace dimensions in the insert slice op with corresponding static dims
    // from the cast source type. If the insert slice sizes have static dims
    // that are not static in the tensor.cast source (i.e., when the cast op
    // casts a dynamic dim to static), the dim should not be replaced, and the
    // pattern will fail later in `verifyInsertSliceOp`.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // Pattern does not apply if the produced op would not verify.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
                            mixedSizes, insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement =
        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),
                                           replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
3145
/// If additional static type information can be deduced from a insert_slice's
/// size operands, insert an explicit cast of the op's source operand. This
/// enables other canonicalization patterns that are matching for tensor_cast
/// ops such as `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
///    : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
///    : tensor<64x64xf32> into ...
/// ```
///
/// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    // Only non-rank-reducing inserts are handled here.
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    // Refine each dynamic source dim with the constant size, when known.
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is:
    //   1) Different from srcType.
    //   2) "More static" than srcType.
    //   3) Cast-compatible with srcType.
    // Insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
    // that the insertion point is just before the InParallelOp in the
    // parallel case.
    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
3217} // namespace
3218
3219llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3220 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3221}
3222
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Fold constant operands, absorb producing tensor.cast ops, and insert
  // static-information-recovering casts on the source.
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
3229
                                          Location loc,
                                          Value tensor,
                                          Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  // Whole-destination insert: zero offsets, the destination's sizes, and
  // unit strides; the source/dest type mismatch carries the rank reduction.
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
3242
3243//===----------------------------------------------------------------------===//
3244// PadOp
3245//===----------------------------------------------------------------------===//
3246
/// Suggest "padded" as the SSA name prefix for the op's result.
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}
3250
3251LogicalResult PadOp::verify() {
3252 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3253 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3254 auto expectedType =
3255 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3256 if (!expectedType) {
3257 return emitError("failed to infer expectedType from sourceType ")
3258 << sourceType << ", specified resultType is " << resultType;
3259 }
3260 if (resultType.getRank() != expectedType.getRank()) {
3261 return emitError("specified type ")
3262 << resultType << " does not match the inferred type "
3263 << expectedType;
3264 }
3265 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3266 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3267 continue;
3268 if (expectedType.isDynamicDim(i))
3269 continue;
3270 return emitError("specified type ")
3271 << resultType << " does not match the inferred type "
3272 << expectedType;
3273 }
3274
3275 return success();
3276}
3277
3278LogicalResult PadOp::verifyRegions() {
3279 auto &region = getRegion();
3280 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3281 Block &block = region.front();
3282 if (block.getNumArguments() != rank)
3283 return emitError("expected the block to have ") << rank << " arguments";
3284
3285 // Note: the number and type of yield values are checked in the YieldOp.
3286 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3287 if (!en.value().isIndex())
3288 return emitOpError("expected block argument ")
3289 << (en.index() + 1) << " to be an index";
3290 }
3291
3292 // Ensure that the region yields an element of the right type.
3293 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3294 if (yieldOp.getValue().getType() !=
3295 llvm::cast<ShapedType>(getType()).getElementType())
3296 return emitOpError("expected yield type to match shape element type");
3297
3298 return success();
3299}
3300
3301RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3302 ArrayRef<int64_t> staticLow,
3303 ArrayRef<int64_t> staticHigh,
3304 ArrayRef<int64_t> resultShape) {
3305 unsigned rank = sourceType.getRank();
3306 if (staticLow.size() != rank)
3307 return RankedTensorType();
3308 if (staticHigh.size() != rank)
3309 return RankedTensorType();
3310 if (!resultShape.empty() && resultShape.size() != rank)
3311 return RankedTensorType();
3312
3313 SmallVector<int64_t, 4> inferredShape;
3314 for (auto i : llvm::seq<unsigned>(0, rank)) {
3315 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3316 staticHigh[i] == ShapedType::kDynamic) {
3317 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3318 : resultShape[i]);
3319 } else {
3320 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3321 assert((resultShape.empty() || size == resultShape[i] ||
3322 resultShape[i] == ShapedType::kDynamic) &&
3323 "mismatch between inferred shape and result shape");
3324 inferredShape.push_back(size);
3325 }
3326 }
3327
3328 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3329}
3330
/// Build a PadOp from static low/high padding plus the dynamic pad Values;
/// infers the result type when `resultType` is null.
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<int64_t> staticLow,
                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  // Infer the result type when the caller did not provide one.
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
3343
/// Build a PadOp whose padding amounts are all dynamic Values.
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ValueRange low, ValueRange high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  // Mark every static entry as kDynamic: all amounts come from `low`/`high`.
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);
}
3353
/// Build a PadOp from OpFoldResult padding amounts, splitting each entry
/// into its static (integer) and dynamic (Value) components.
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh have full information of the padding config.
  // This will grow staticLow and staticHigh with 1 value. If the config is
  // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
  // value as well.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
  if (!resultType) {
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  }
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
3376
3377void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3378 Value source, ArrayRef<OpFoldResult> low,
3379 ArrayRef<OpFoldResult> high, Value constantPadValue,
3380 bool nofold, ArrayRef<NamedAttribute> attrs) {
3381 build(b, result, resultType, source, low, high, nofold, attrs);
3382
3383 // Add a region and a block to yield the pad value.
3384 Region *region = result.regions[0].get();
3385 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3386 Repeated<Type> blockArgTypes(sourceRank, b.getIndexType());
3387 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3388
3389 // `builder.createBlock` changes the insertion point within the block. Create
3390 // a guard to reset the insertion point of the builder after it is destroyed.
3391 OpBuilder::InsertionGuard guard(b);
3392 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3393 tensor::YieldOp::create(b, result.location, constantPadValue);
3394}
3395
3396llvm::SmallBitVector PadOp::getPaddedDims() {
3397 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3398 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3399 for (const auto &en : enumerate(paddingWidths))
3400 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3401 paddedDims.set(en.index());
3402 };
3403 extractPaddedDims(getMixedLowPad());
3404 extractPaddedDims(getMixedHighPad());
3405 return paddedDims;
3406}
3407
3408namespace {
3409// Folds tensor.pad when padding is static zeros and the attribute
3410// doesn't request otherwise.
3411struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3412 using OpRewritePattern<PadOp>::OpRewritePattern;
3413
3414 LogicalResult matchAndRewrite(PadOp padTensorOp,
3415 PatternRewriter &rewriter) const override {
3416 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3417 return failure();
3418 if (padTensorOp.getNofold())
3419 return failure();
3420 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3421 padTensorOp, padTensorOp.getResult().getType(),
3422 padTensorOp.getSource());
3423 return success();
3424 }
3425};
3426
3427// Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // The pad's source must be a tensor.cast that can be folded into its
    // consumer (i.e. the cast's source is at least as static).
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    // Infer the pad result type from the cast's (more static) source type,
    // keeping any static sizes the current result type already has.
    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      // Result type unchanged: simply bypass the cast in place.
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      // Result type became more static: build a new pad with the refined
      // type, clone the pad body, and cast back to the original type for
      // existing users.
      auto newOp = PadOp::create(
          rewriter, padTensorOp->getLoc(), newResultType,
          padTensorOp.getSource(), padTensorOp.getStaticLow(),
          padTensorOp.getStaticHigh(), padTensorOp.getLow(),
          padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
3462
3463// Fold CastOp using the result of PadOp back into the latter if it adds
3464// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Only handle the single-use case: the pad result must feed exactly one
    // tensor.cast.
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    // The cast must keep or refine the static information of the pad result.
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    // Recreate the pad directly with the cast's (more static) result type,
    // moving the pad body over instead of cloning it.
    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    // Both the pad and the cast are replaced by the new pad's result.
    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
3493
3494/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3495/// different dimensions. The pattern applies if the following preconditions
3496/// hold:
3497/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3498/// 2) the tensor::ExtractSliceOps have only unit-strides,
3499/// 3) the tensor::PadOps perform only high-padding,
3500/// 4) the tensor::PadOps have the same constant padding value,
3501/// 5) the tensor::PadOps do not have common padding dimensions,
3502/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3503/// zero-offset for every dimension.
3504/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3505/// the
3506/// padded source dimensions.
3507///
3508/// Example:
3509///
3510/// ```mlir
3511/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3512/// : tensor<64x64xf32> to tensor<?x64xf32>
3513/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3514/// } : tensor<?x64xf32> to tensor<8x64xf32>
3515/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3516/// : tensor<8x64xf32> to tensor<8x?xf32>
3517/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3518/// } : tensor<8x?xf32> to tensor<8x4xf32>
3519/// ```
3520///
3521/// folds into:
3522///
3523/// ```mlir
3524/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3525/// : tensor<64x64xf32> to tensor<?x?xf32>
3526/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3527/// } : tensor<?x?xf32> to tensor<8x4xf32>
3528/// ```
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // Match the producer chain: outer slice -> outer pad -> inner slice ->
    // inner pad (= `padOp`). The outer pad must be foldable.
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    // Both pads must yield the same constant; matchPattern requires an
    // actual ConstantLike producer, not just any block-external value.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // for every dimension, and use the offset the other pair. Fail if no
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // exists.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
    // of the outer tensor::ExtractSliceOp for the dimensions padded by the
    // outer tensor::PadOp and fail if the size of the inner
    // tensor::ExtractSliceOp does not match the size of the padded dimension.
    // Otherwise, take the size of the inner tensor::ExtractSliceOp.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    // Since step 5 guarantees the padded dimension sets are disjoint, at
    // most one of the two branches fires per dimension.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // the two paddings in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
3651
/// Rewrites a tensor.pad whose dynamic low/high padding operands are defined
/// by integer constants into a pad with the corresponding static padding
/// entries, computes a more static result type, and casts the new result
/// back to the original (less static) result type.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    // `constOperandsLow/High` record the constant value (or kDynamic) for
    // each dynamic operand; `newLows/newHighs` keep only the operands that
    // remain dynamic after folding.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    // Merge the folded constants into the static padding arrays; each
    // kDynamic entry consumes the next per-operand value in order.
    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
             staticHigh[i] == ShapedType::kDynamic ||
             inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    // Nothing to do if no dimension became more static.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    // Clone the pad body and cast back to the original result type for
    // existing users.
    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);

    return success();
  }
};
3751
3752/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3753///
3754/// Example:
3755///
3756/// ```mlir
3757/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3758/// tensor.yield %val
3759/// } : tensor<1x2xf32> to tensor<2x5xf32>
3760/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3761/// tensor.yield %val
3762/// } : tensor<1x5xf32> to tensor<5x7xf32>
3763/// ```
3764///
3765/// folds into:
3766///
3767/// ```mlir
3768/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3769/// tensor.yield %val
3770/// } : tensor<1x2xf32> to tensor<5x7xf32>
3771/// ```
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    // Note: the two pads must yield the *same* SSA value, not merely equal
    // constants.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    // Each combined entry is the (folded) affine sum d0 + d1 of the
    // corresponding consumer and producer padding amounts.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    // Build the combined pad directly on the producer's source and move the
    // consumer's body region over.
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
3828
3829} // namespace
3830
/// Reify the result shape of the pad: static dims become index attributes;
/// dynamic dims are computed as source_dim + low_pad + high_pad via a folded
/// affine apply.
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    // Static result dims are known directly from the type.
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Location loc = getLoc();
    // Materialize the source dim size and add both padding amounts.
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
3853
/// Register the PadOp canonicalization patterns defined above.
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
3860
3861/// Return the padding value of the PadOp if it constant. In this context,
3862/// "constant" means an actual constant or "defined outside of the block".
3863///
3864/// Values are considered constant in three cases:
3865/// - A ConstantLike value.
3866/// - A basic block argument from a different block.
3867/// - A value defined outside of the block.
3868///
3869/// If the padding value is not constant, an empty Value is returned.
3870Value PadOp::getConstantPaddingValue() {
3871 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3872 if (!yieldOp)
3873 return {};
3874 Value padValue = yieldOp.getValue();
3875 // Check if yield value is a constant.
3876 if (matchPattern(padValue, m_Constant()))
3877 return padValue;
3878 // Check if yield value is defined inside the PadOp block.
3879 if (padValue.getParentBlock() == &getRegion().front())
3880 return {};
3881 // Else: Yield value defined outside of the PadOp block.
3882 return padValue;
3883}
3884
3885OpFoldResult PadOp::fold(FoldAdaptor) {
3886 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3887 !getNofold())
3888 return getSource();
3889 return {};
3890}
3891
3892//===----------------------------------------------------------------------===//
3893// ParallelInsertSliceOp
3894//===----------------------------------------------------------------------===//
3895
3896OpResult ParallelInsertSliceOp::getTiedOpResult() {
3897 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3898 for (const auto &it :
3899 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3900 Operation &nextOp = it.value();
3901 if (&nextOp == getOperation())
3902 return parallelCombiningParent.getParentResult(it.index());
3903 }
3904 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3905}
3906
3907// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3908void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3909 Value source, Value dest,
3910 ArrayRef<OpFoldResult> offsets,
3911 ArrayRef<OpFoldResult> sizes,
3912 ArrayRef<OpFoldResult> strides,
3913 ArrayRef<NamedAttribute> attrs) {
3914 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3915 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3916 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3917 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3918 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3919 result.addAttributes(attrs);
3920 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3921 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3922 b.getDenseI64ArrayAttr(staticSizes),
3923 b.getDenseI64ArrayAttr(staticStrides));
3924}
3925
/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
/// packed into a Range vector: each Range contributes one offset, size and
/// stride entry.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
3935
3936// Build a ParallelInsertSliceOp with dynamic entries.
3937void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3938 Value source, Value dest, ValueRange offsets,
3939 ValueRange sizes, ValueRange strides,
3940 ArrayRef<NamedAttribute> attrs) {
3941 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3942 offsets, [](Value v) -> OpFoldResult { return v; });
3943 SmallVector<OpFoldResult> sizeValues =
3944 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3945 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3946 strides, [](Value v) -> OpFoldResult { return v; });
3947 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3948}
3949
3950// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3951// to 0, strides set to 1 and inferred result type.
3952void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3953 Value dest, ArrayRef<OpFoldResult> sizes,
3954 ArrayRef<NamedAttribute> attrs) {
3955 Attribute zeroIdxAttr = b.getIndexAttr(0);
3956 Attribute oneIdxAttr = b.getIndexAttr(1);
3957 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3958 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3959 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3960}
3961
3962LogicalResult ParallelInsertSliceOp::verify() {
3963 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3964 return this->emitError("expected InParallelOpInterface parent, got:")
3965 << *(getOperation()->getParentOp());
3966
3967 // Verify result type against inferred type.
3968 RankedTensorType expectedType;
3970 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3971 getStaticSizes(), getStaticStrides(), &expectedType);
3973 return produceSliceErrorMsg(result, *this, expectedType);
3974
3975 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3976 // to the destination tensor.
3977 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3978 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3979 getStaticStrides(), /*generateErrorMessage=*/true);
3980 if (!boundsResult.isValid)
3981 return getOperation()->emitError(boundsResult.errorMessage);
3982
3983 return success();
3984}
3985
/// Register the shared InsertSliceOp-family canonicalization patterns,
/// instantiated for ParallelInsertSliceOp.
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
3992
/// Return the dimensions of the dest that are dropped (rank-reducing) by
/// this insert, computed from the source shape and the mixed sizes.
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
3996
// ParallelCombiningOpInterface implementation.
/// The destination operand is the value updated in parallel.
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}
4001
4002Operation *ParallelInsertSliceOp::getIteratingParent() {
4003 // Return the parent InParallelOpInterface's parent.
4004 if (auto combiningOp =
4005 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
4006 return combiningOp->getParentOp();
4007 return nullptr;
4008}
4009
4010//===----------------------------------------------------------------------===//
4011// ScatterOp
4012//===----------------------------------------------------------------------===//
4013
/// Suggest "scatter" as the printed SSA name for the result.
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
4018
/// Verify a ScatterOp: scatter dims, the (currently mandatory) `unique`
/// attribute, and that the source type matches one of the two forms inferred
/// by GatherOp::inferResultType on the dest.
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");
  // TODO: we could also check statically that there are fewer leading index
  // tensor dims than the dest dims. If this is not the case, the unique
  // attribute cannot be true.

  // Use the GatherOp::inferResultType on the `dest` type and verify the
  // expected type matches the source type.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  // Accept both the full-rank and the rank-reduced variant of the source.
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type "
                       "mismatch: "
                       "expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
4051
4052//===----------------------------------------------------------------------===//
4053// SplatOp
4054//===----------------------------------------------------------------------===//
4055
/// Build a SplatOp from an explicit aggregate result type; forwards to the
/// generated builder with the argument order it expects.
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}
4060
4061void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4062 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4063 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4064 build(builder, result, aggregateType, element, dynamicSizes);
4065}
4066
4067void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4068 ArrayRef<OpFoldResult> sizes) {
4069 SmallVector<int64_t> staticShape;
4070 SmallVector<Value> dynamicSizes;
4071 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4072 build(builder, result, element, staticShape, dynamicSizes);
4073}
4074
/// Suggest "splat" as the printed SSA name for the result.
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
4079
/// Verify that the number of dynamic size operands matches the number of
/// dynamic dimensions in the result type.
LogicalResult SplatOp::verify() {
  return verifyDynamicDimensionCount(getOperation(), getType(),
                                     getDynamicSizes());
}
4084
4085LogicalResult
4086SplatOp::reifyResultShapes(OpBuilder &builder,
4087 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4088 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4089 unsigned ctr = 0;
4090 for (int64_t i = 0; i < getType().getRank(); ++i) {
4091 if (getType().isDynamicDim(i)) {
4092 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4093 } else {
4094 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4095 }
4096 }
4097 return success();
4098}
4099
4100OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4101 auto constOperand = adaptor.getInput();
4102 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4103 return {};
4104
4105 // Do not fold if the splat is not statically shaped
4106 if (!getType().hasStaticShape())
4107 return {};
4108
4109 // SplatElementsAttr::get treats single value for second arg as being a
4110 // splat.
4111 return SplatElementsAttr::get(getType(), {constOperand});
4112}
4113
4114//===----------------------------------------------------------------------===//
4115// Common Canonicalizers and Folders.
4116//===----------------------------------------------------------------------===//
4117static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4118 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4119 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4120 // might need special handling of attached regions.
4121 if (isa<InsertSliceOp>(op.getOperation()) ||
4122 isa<LoopLikeOpInterface>(op.getOperation()))
4123 return false;
4124
4126}
4127
4128/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4129/// the `tensor.cast` has source that is more static than the consuming op.
4130///
4131/// Example:
4132/// ```mlir
4133/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4134/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4135/// ```
4136///
4137/// folds into:
4138///
4139/// ```mlir
4140/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4141/// ```
4142/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4143/// can add the pattern to their canonicalizers.
4145 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4147 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4148
4149 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4150 PatternRewriter &rewriter) const override {
4151
4152 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4153 // for that instead.
4154 if (!foldTensorCastPrecondition(op) ||
4155 isa<linalg::RelayoutOpInterface>(*op))
4156 return failure();
4157
4158 SmallVector<Type> newResultTypes(op->getResultTypes());
4159 SmallVector<Value> newOperands =
4160 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4161
4162 // Clone op
4163 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4164
4165 SmallVector<Value, 4> replacements;
4166 replacements.reserve(newOp->getNumResults());
4167 for (auto [oldResult, newResult] :
4168 llvm::zip(op->getResults(), newOp->getResults())) {
4169 if (newResult.getType() != oldResult.getType()) {
4170 replacements.push_back(tensor::CastOp::create(
4171 rewriter, op->getLoc(), oldResult.getType(), newResult));
4172 } else {
4173 replacements.push_back(newResult);
4174 }
4175 }
4176 rewriter.replaceOp(op, replacements);
4177
4178 return success();
4179 }
4180};
4181
4182//===----------------------------------------------------------------------===//
4183// TensorDialect
4184//===----------------------------------------------------------------------===//
4185
/// Register dialect-wide canonicalization patterns (cast-producer folding
/// for DPS ops).
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}
4190
4191//===----------------------------------------------------------------------===//
4192// TableGen'd op method definitions
4193//===----------------------------------------------------------------------===//
4194
4195#define GET_OP_CLASSES
4196#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:382
MLIRContext * getContext() const
Definition Builders.h:56
auto value_begin() const
Get an iterator of the given type to the start of the held element values.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:78
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:93
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.