MLIR 23.0.0git
TensorOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Builders.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/Matchers.h"
33#include "mlir/Support/LLVM.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/Repeated.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/SmallBitVector.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/StringRef.h"
40#include "llvm/Support/Casting.h"
41#include "llvm/Support/MathExtras.h"
42#include <optional>
43
44using namespace mlir;
45using namespace mlir::tensor;
46
47/// Materialize a single constant operation from a given attribute value with
48/// the desired resultant type.
49Operation *TensorDialect::materializeConstant(OpBuilder &builder,
50 Attribute value, Type type,
51 Location loc) {
52 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
53 return op;
54 if (complex::ConstantOp::isBuildableWith(value, type))
55 return complex::ConstantOp::create(builder, loc, type,
56 llvm::cast<ArrayAttr>(value));
57 return nullptr;
58}
59
61 int64_t dim) {
62 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
63 if (tensorType.isDynamicDim(dim))
64 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
65
66 return builder.getIndexAttr(tensorType.getDimSize(dim));
67}
68
70 Location loc, Value value) {
71 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
73 for (int64_t i = 0; i < tensorType.getRank(); ++i)
74 result.push_back(getMixedSize(builder, loc, value, i));
75 return result;
76}
77
79 OpResult opResult) {
80 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
81 assert(tensorType && "expected tensor type");
82
83 // If the op has a destination, it implements DestinationStyleOpInterface and
84 // we can query the destination operand from that interface.
85 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
86 if (destOp)
87 return destOp.getTiedOpOperand(opResult)->get();
88
89 // Otherwise, create a new destination tensor with the same shape.
91 b.setInsertionPoint(opResult.getDefiningOp());
92
93 // Compute sizes.
95 if (!tensorType.hasStaticShape()) {
96 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
97 ReifiedRankedShapedTypeDims reifiedShapes;
98 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
99 return failure();
100 mixedSizes = reifiedShapes[opResult.getResultNumber()];
101 } else {
102 // Static shape: Take static sizes directly.
103 for (int64_t sz : tensorType.getShape())
104 mixedSizes.push_back(b.getIndexAttr(sz));
105 }
106
107 // Create empty tensor with the same encoding as the result type.
108 Attribute encoding;
109 if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType))
110 encoding = rankedTensorType.getEncoding();
111 Value emptyTensor = tensor::EmptyOp::create(
112 b, loc, mixedSizes, tensorType.getElementType(), encoding);
113 return emptyTensor;
114}
115
117 Operation *op,
119 for (OpResult opResult : op->getResults()) {
120 if (llvm::isa<TensorType>(opResult.getType())) {
121 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
122 if (failed(destination))
123 return failure();
124 result.push_back(*destination);
125 }
126 }
127 return success();
128}
129
131 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
132 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
133 return rtp1.getShape() == rtp2.getShape() &&
134 rtp1.getElementType() == rtp2.getElementType();
135 return false;
136 }
137 return tp1 == tp2; // default implementation
138}
139
140/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
141/// rank-extending tensor.insert_slice op.
142static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
143 ArrayRef<OpFoldResult> mixedSizes) {
144 llvm::SmallBitVector droppedDims(mixedSizes.size());
145 int64_t shapePos = reducedShape.size() - 1;
146
147 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
148 size_t idx = mixedSizes.size() - size.index() - 1;
149 // Rank-reduced dims must have a static unit dimension.
150 bool isStaticUnitSize =
151 isa<Attribute>(size.value()) &&
152 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
153
154 if (shapePos < 0) {
155 // There are no more dims in the reduced shape. All remaining sizes must
156 // be rank-reduced dims.
157 assert(isStaticUnitSize && "expected unit dim");
158 droppedDims.set(idx);
159 continue;
160 }
161
162 // Dim is preserved if the size is not a static 1.
163 if (!isStaticUnitSize) {
164 --shapePos;
165 continue;
166 }
167
168 // Dim is preserved if the reduced shape dim is also 1.
169 if (reducedShape[shapePos] == 1) {
170 --shapePos;
171 continue;
172 }
173
174 // Otherwise: Dim is dropped.
175 droppedDims.set(idx);
176 }
177
178 assert(shapePos < 0 && "dimension mismatch");
179 return droppedDims;
180}
181
182/// Given a ranked tensor type and a range of values that defines its dynamic
183/// dimension sizes, turn all dynamic sizes that have a constant value into
184/// static dimension sizes.
185static RankedTensorType
186foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
187 SmallVector<Value> &foldedDynamicSizes) {
188 SmallVector<int64_t> staticShape(type.getShape());
189 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
190 "incorrect number of dynamic sizes");
191
192 // Compute new static and dynamic sizes.
193 unsigned ctr = 0;
194 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
195 if (type.isDynamicDim(i)) {
196 Value dynamicSize = dynamicSizes[ctr++];
197 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
198 if (cst.has_value()) {
199 // Dynamic size must be non-negative.
200 if (cst.value() < 0) {
201 foldedDynamicSizes.push_back(dynamicSize);
202 continue;
203 }
204 staticShape[i] = *cst;
205 } else {
206 foldedDynamicSizes.push_back(dynamicSize);
207 }
208 }
209 }
210
211 return RankedTensorType::get(staticShape, type.getElementType(),
212 type.getEncoding());
213}
214
215//===----------------------------------------------------------------------===//
216// BitcastOp
217//===----------------------------------------------------------------------===//
218
219bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
220 if (inputs.size() != 1 || outputs.size() != 1)
221 return false;
222 Type a = inputs.front(), b = outputs.front();
223 auto aT = dyn_cast<TensorType>(a);
224 auto bT = dyn_cast<TensorType>(b);
225 if (!aT || !bT)
226 return false;
227
228 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
229 return false;
230
231 return succeeded(verifyCompatibleShape(aT, bT));
232}
233
234namespace {
235
236/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
237/// operation.
238struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
239 using OpRewritePattern<BitcastOp>::OpRewritePattern;
240
241 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
242 PatternRewriter &rewriter) const final {
243 auto tensorBitcastOperand =
244 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
245 if (!tensorBitcastOperand)
246 return failure();
247
248 auto resultType = cast<TensorType>(tensorBitcast.getType());
249 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
250 tensorBitcastOperand.getOperand());
251 return success();
252 }
253};
254
255} // namespace
256
257void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
258 MLIRContext *context) {
259 results.add<ChainedTensorBitcast>(context);
260}
261
262//===----------------------------------------------------------------------===//
263// CastOp
264//===----------------------------------------------------------------------===//
265
266void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
267 setNameFn(getResult(), "cast");
268}
269
270/// Returns true if `target` is a ranked tensor type that preserves static
271/// information available in the `source` ranked tensor type.
273 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
274 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
275
276 // Requires RankedTensorType.
277 if (!sourceType || !targetType)
278 return false;
279
280 // Requires same elemental type.
281 if (sourceType.getElementType() != targetType.getElementType())
282 return false;
283
284 // Requires same rank.
285 if (sourceType.getRank() != targetType.getRank())
286 return false;
287
288 // Requires same encoding.
289 if (sourceType.getEncoding() != targetType.getEncoding())
290 return false;
291
292 // If cast is towards more static sizes along any dimension, don't fold.
293 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
294 if (ShapedType::isStatic(std::get<0>(t)) &&
295 ShapedType::isDynamic(std::get<1>(t)))
296 return false;
297 }
298
299 return true;
300}
301
302/// Determines whether tensor::CastOp casts to a more dynamic version of the
303/// source tensor. This is useful to fold a tensor.cast into a consuming op and
304/// implement canonicalization patterns for ops in different dialects that may
305/// consume the results of tensor.cast operations. Such foldable tensor.cast
306/// operations are typically inserted as `slice` ops and are canonicalized,
307/// to preserve the type compatibility of their uses.
308///
309/// Returns true when all conditions are met:
310/// 1. source and result are ranked tensors with same element type and rank.
311/// 2. the source type has at least as much static information as the result.
312///
313/// Example:
314/// ```mlir
315/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
316/// %2 = consumer %1 ... : tensor<?x?xf32> ...
317/// ```
318///
319/// folds into:
320///
321/// ```mlir
322/// %2 = consumer %0 ... : tensor<8x16xf32> ...
323/// ```
325 if (!castOp)
326 return false;
327
328 // Can fold if the source of cast has at least as much static information as
329 // its results.
330 return preservesStaticInformation(castOp.getType(),
331 castOp.getSource().getType());
332}
333
334/// Determines whether the tensor::CastOp casts to a more static version of the
335/// source tensor. This is useful to fold into a producing op and implement
336/// canonicalization patterns with the `tensor.cast` op as the root, but
337/// producer being from different dialects. Returns true when all conditions are
338/// met:
339/// 1. source and result are ranked tensors with same element type and rank.
340/// 2. the result type has more static information than the source.
341///
342/// Example:
343/// ```mlir
344/// %1 = producer ... : tensor<?x?xf32>
345/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
346/// ```
347///
348/// can be canonicalized to :
349///
350/// ```mlir
351/// %2 = producer ... : tensor<8x16xf32>
352/// ```
353/// Not all ops might be canonicalizable this way, but for those that can be,
354/// this method provides a check that it is worth doing the canonicalization.
356 if (!castOp)
357 return false;
358 return preservesStaticInformation(castOp.getSource().getType(),
359 castOp.getType());
360}
361
363 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
364 if (llvm::isa<BlockArgument>(opOperand.get()))
365 return false;
366 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
367 return castOp && canFoldIntoConsumerOp(castOp);
368 });
369}
370
372 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
373 SmallVector<Value> newOperands;
374 newOperands.reserve(op->getNumOperands());
375
376 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
377
378 // Assumes that the result has dpsInits followed by nonDpsInits.
379 int64_t dpsInitIdx = 0;
380 for (OpOperand &opOperand : op->getOpOperands()) {
381 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
382 bool fold = canFoldIntoConsumerOp(tensorCastOp);
383 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
384 if (op.isDpsInit(&opOperand) &&
385 !llvm::isa<MemRefType>(newOperands.back().getType()))
386 newResTy[dpsInitIdx++] = newOperands.back().getType();
387 }
388 return newOperands;
389}
390
391/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
392/// that can be folded.
394 bool folded = false;
395 for (OpOperand &operand : op->getOpOperands()) {
396 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
397 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
398 operand.set(castOp.getOperand());
399 folded = true;
400 }
401 }
402 return success(folded);
403}
404
405bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
406 if (inputs.size() != 1 || outputs.size() != 1)
407 return false;
408 Type a = inputs.front(), b = outputs.front();
409 auto aT = llvm::dyn_cast<TensorType>(a);
410 auto bT = llvm::dyn_cast<TensorType>(b);
411 if (!aT || !bT)
412 return false;
413
414 if (aT.getElementType() != bT.getElementType())
415 return false;
416
417 return succeeded(verifyCompatibleShape(aT, bT));
418}
419
420/// Compute a TensorType that has the joined shape knowledge of the two
421/// given TensorTypes. The element types need to match.
423 assert(one.getElementType() == two.getElementType());
424
425 if (!one.hasRank())
426 return two;
427 if (!two.hasRank())
428 return one;
429
430 int64_t rank = one.getRank();
431 if (rank != two.getRank())
432 return {};
433
435 join.reserve(rank);
436 for (int64_t i = 0; i < rank; ++i) {
437 if (one.isDynamicDim(i)) {
438 join.push_back(two.getDimSize(i));
439 continue;
440 }
441 if (two.isDynamicDim(i)) {
442 join.push_back(one.getDimSize(i));
443 continue;
444 }
445 if (one.getDimSize(i) != two.getDimSize(i))
446 return {};
447 join.push_back(one.getDimSize(i));
448 }
449 return RankedTensorType::get(join, one.getElementType());
450}
451
452namespace {
453
454/// Replaces chains of two tensor.cast operations by a single tensor.cast
455/// operation if doing so does not remove runtime constraints.
456struct ChainedTensorCast : public OpRewritePattern<CastOp> {
457 using OpRewritePattern<CastOp>::OpRewritePattern;
458
459 LogicalResult matchAndRewrite(CastOp tensorCast,
460 PatternRewriter &rewriter) const final {
461 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
462
463 if (!tensorCastOperand)
464 return failure();
465
466 auto sourceType =
467 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
468 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
469 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
470
471 // We can remove the intermediate cast if joining all three produces the
472 // same result as just joining the source and result shapes.
473 auto firstJoin =
474 joinShapes(joinShapes(sourceType, intermediateType), resultType);
475
476 // The join might not exist if the cast sequence would fail at runtime.
477 if (!firstJoin)
478 return failure();
479
480 // The newJoin always exists if the above join exists, it might just contain
481 // less information. If so, we cannot drop the intermediate cast, as doing
482 // so would remove runtime checks.
483 auto newJoin = joinShapes(sourceType, resultType);
484 if (firstJoin != newJoin)
485 return failure();
486
487 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
488 tensorCastOperand.getOperand());
489 return success();
490 }
491};
492
493/// Fold tensor.cast into tesor.extract_slice producer.
494/// Example:
495/// ```
496/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
497/// tensor<128x512xf32> to tensor<?x512xf32>
498/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
499/// ```
500/// ->
501/// ```
502/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
503/// tensor<128x512xf32> to tensor<16x512xf32>
504/// ```
505struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
506 using OpRewritePattern<CastOp>::OpRewritePattern;
507
508 LogicalResult matchAndRewrite(CastOp tensorCast,
509 PatternRewriter &rewriter) const final {
510 auto extractOperand =
511 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
512
513 // Cannot fold cast to unranked tensor.
514 auto rankedResultType =
515 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
516 if (!rankedResultType)
517 return failure();
518
519 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
520 rankedResultType.getShape() ==
521 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
522 .getShape())
523 return failure();
524
525 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
526 auto dimMask = computeRankReductionMask(
527 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
528 size_t dimIndex = 0;
529 for (size_t i = 0, e = sizes.size(); i < e; i++) {
530 if (dimMask && dimMask->count(i))
531 continue;
532 int64_t dim = rankedResultType.getShape()[dimIndex++];
533 if (ShapedType::isDynamic(dim))
534 continue;
535 sizes[i] = rewriter.getIndexAttr(dim);
536 }
537
538 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
539 tensorCast, rankedResultType, extractOperand.getSource(),
540 extractOperand.getMixedOffsets(), sizes,
541 extractOperand.getMixedStrides());
542 return success();
543 }
544};
545
546} // namespace
547
548void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
549 MLIRContext *context) {
550 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
551}
552
553//===----------------------------------------------------------------------===//
554// ConcatOp
555//===----------------------------------------------------------------------===//
556
557RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
558 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
559 auto tensorTypes =
560 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
561 int64_t concatRank = tensorTypes[0].getRank();
562
563 // The concatenation dim must be in the range [0, rank).
564 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
565
566 SmallVector<int64_t> sizes(concatRank);
567 for (int64_t i = 0, e = concatRank; i < e; ++i) {
568 if (i == dim)
569 continue;
570 SaturatedInteger size;
571 for (auto tensorType : tensorTypes)
572 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
573 sizes[i] = size.asInteger();
574 }
575 auto concatSize = SaturatedInteger::wrap(0);
576 for (auto tensorType : tensorTypes)
577 concatSize =
578 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
579 sizes[dim] = concatSize.asInteger();
580 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
581}
582
583void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
584 ValueRange inputs) {
585 FailureOr<RankedTensorType> resultType =
586 inferResultType(dim, inputs.getTypes());
587 assert(succeeded(resultType) && "failed to infer concatenation result type");
588 build(builder, result, *resultType, dim, inputs);
589}
590
591LogicalResult ConcatOp::verify() {
592 if (getInputs().size() < 1)
593 return emitOpError("requires at least one input");
594
596 for (auto input : getInputs())
597 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
598
599 RankedTensorType resultType = getResultType();
600 int64_t resultRank = getRank();
601 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
602 return type.getRank() != resultRank;
603 }))
604 return emitOpError("rank of concatenated inputs must match result rank");
605
606 Type resultElementType = resultType.getElementType();
607 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
608 return type.getElementType() != resultElementType;
609 }))
610 return emitOpError("inputs and result element type must match");
611
612 int64_t dim = getDim();
613 if (dim >= resultRank)
614 return emitOpError("concatenation dim must be less than the tensor rank");
615
616 SmallVector<int64_t> sizes(resultRank);
617 for (int64_t i = 0, e = resultRank; i < e; ++i) {
618 if (i == dim)
619 continue;
620 SaturatedInteger size;
621 for (auto tensorType : inputTypes) {
622 FailureOr<SaturatedInteger> maybeSize =
623 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
624 if (failed(maybeSize))
625 return emitOpError("static concatenation size mismatch along ")
626 << "non-concatenated dimension " << i;
627 size = *maybeSize;
628 }
629 sizes[i] = size.asInteger();
630 }
631 auto concatSize = SaturatedInteger::wrap(0);
632 for (auto tensorType : inputTypes)
633 concatSize =
634 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
635 sizes[dim] = concatSize.asInteger();
636 auto inferredResultType =
637 RankedTensorType::get(sizes, inputTypes[0].getElementType());
638
639 for (auto [inferredSize, actualSize] :
640 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
641 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
642 ShapedType::isDynamic(actualSize);
643 if (!hasDynamic && inferredSize != actualSize)
644 return emitOpError("result type ")
645 << resultType << "does not match inferred shape "
646 << inferredResultType << " static sizes";
647 }
648
649 return success();
650}
651
652FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
653 size_t numInputs = getInputs().size();
654 uint64_t concatDim = getDim();
655
657 inputShapes.reserve(numInputs);
658 SmallVector<OpFoldResult> concatOffsets;
659 concatOffsets.reserve(numInputs);
660 SmallVector<OpFoldResult> outputShape;
661
662 AffineExpr addExpr =
663 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
664 OpFoldResult zero = builder.getIndexAttr(0);
665 Location loc = getLoc();
666 for (auto [index, input] : llvm::enumerate(getInputs())) {
667 SmallVector<OpFoldResult> inputShape =
668 tensor::getMixedSizes(builder, input.getLoc(), input);
669 if (index == 0) {
670 outputShape = inputShape;
671 concatOffsets.push_back(zero);
672 } else {
673 concatOffsets.push_back(outputShape[concatDim]);
674 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
675 builder, loc, addExpr,
676 {outputShape[concatDim], inputShape[concatDim]});
677 }
678 inputShapes.emplace_back(std::move(inputShape));
679 }
680
681 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
683
684 int64_t rank = getType().getRank();
685 OpFoldResult one = builder.getIndexAttr(1);
686 SmallVector<OpFoldResult> strides(rank, one);
687 SmallVector<OpFoldResult> offsets(rank, zero);
688 for (auto [index, input] : llvm::enumerate(getInputs())) {
689 offsets[concatDim] = concatOffsets[index];
690 auto insertSlice = tensor::InsertSliceOp::create(
691 builder, loc, input, replacement, offsets, inputShapes[index], strides);
692 replacement = insertSlice.getResult();
693 }
694 if (replacement.getType() != getType()) {
695 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
696 }
698}
699
700LogicalResult
701ConcatOp::reifyResultShapes(OpBuilder &builder,
702 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
703 ValueRange inputs = getInputs();
704 int64_t dim = getDim();
705 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
706
707 Value init = inputs[0];
708 int64_t rank = getType().getRank();
709
710 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
711
712 // Pre-populate the result sizes with as much static information as possible
713 // from the given result type, as well as the inferred result type, otherwise
714 // use the dim sizes from the first input.
715 for (int64_t i = 0; i < rank; ++i) {
716 if (i == dim)
717 continue;
718 if (!getType().isDynamicDim(i)) {
719 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
720 } else if (!inferredResultType.isDynamicDim(i)) {
721 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
722 builder, getLoc(),
723 builder.getIndexAttr(inferredResultType.getDimSize(i)));
724 } else {
725 reifiedReturnShapes[0][i] =
726 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
727 }
728 }
729
730 if (getType().isDynamicDim(dim)) {
731 // Take the sum of the input sizes along the concatenated dim.
732 AffineExpr sum = builder.getAffineDimExpr(0);
734 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
735 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
736 sum = sum + builder.getAffineDimExpr(idx + 1);
737 sizes.push_back(
738 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
739 }
740 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
741 builder, getLoc(),
742 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
743 } else {
744 // If the result shape is static along the concatenated dim, use the static
745 // shape.
746 reifiedReturnShapes[0][dim] =
747 builder.getIndexAttr(getType().getDimSize(dim));
748 }
749 return success();
750}
751
752void ConcatOp::getAsmResultNames(
753 function_ref<void(Value, StringRef)> setNameFn) {
754 setNameFn(getResult(), "concat");
755}
756
757OpFoldResult ConcatOp::fold(FoldAdaptor) {
758 ValueRange inputs = getInputs();
759 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
760 return inputs[0];
761 return {};
762}
763
764namespace {
765/// Fold a concat op with a single input to a cast.
766struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
767 using OpRewritePattern<ConcatOp>::OpRewritePattern;
768
769 LogicalResult matchAndRewrite(ConcatOp concatOp,
770 PatternRewriter &rewriter) const override {
771 if (concatOp.getInputs().size() != 1)
772 return failure();
773 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
774 concatOp.getInputs()[0]);
775 return success();
776 }
777};
778
779/// Propagate static shapes into the operands of a `tensor.concat`.
780///
781/// `tensor.concat` requires every operand to match on all dimensions except the
782/// concatenation dimension. If one operand is already static in those
783/// dimensions, the other operands may safely be refined to that same static
784/// shape.
785///
786/// Example:
787///
788/// ```mlir
789/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
790/// tensor<?x12xi32>
791/// ```
792/// ->
793/// ```mlir
794/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
795/// %2 = tensor.concat dim(0) %0, %cast :
796/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
797/// ```
798struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
799 using OpRewritePattern<ConcatOp>::OpRewritePattern;
800
801 LogicalResult matchAndRewrite(ConcatOp concatOp,
802 PatternRewriter &rewriter) const override {
803 int64_t dim = concatOp.getDim();
804 RankedTensorType inferredResultType =
805 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
806
807 // Find operands for which a more static shape can be inferred.
808 LogicalResult matched = failure();
809 // Inferred operand shapes are identical in every dimension except the
810 // concatenation dimension.
811 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
812 for (auto [operandIdx, operandType] :
813 llvm::enumerate(concatOp->getOperandTypes())) {
814 // Compute inferred type for operand.
815 inferredOperandShape[dim] =
816 cast<RankedTensorType>(operandType).getDimSize(dim);
817 auto inferredOperandType = RankedTensorType::get(
818 inferredOperandShape, inferredResultType.getElementType());
819
820 // Check if inferred type is more static.
821 if (!preservesStaticInformation(inferredOperandType, operandType)) {
822 matched = success();
823
824 // Use refined operand type and create cast from original operand.
825 auto castOp =
826 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
827 concatOp.getOperand(operandIdx));
828 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
829 concatOp->setOperand(operandIdx, castOp->getResult(0));
830 });
831 }
832 }
833
834 return matched;
835 }
836};
837
838// Ensure `tensor.concat`'s result type is at least as static as can be inferred
839// from its operand types.
840///
841/// Example:
842/// ```mlir
843/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
844/// tensor<?x?xi32>
845/// ```
846/// ->
847/// ```mlir
848/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
849/// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
850/// tensor<?x?xi32>
851/// ```
852struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
853 using OpRewritePattern<ConcatOp>::OpRewritePattern;
854
855 LogicalResult matchAndRewrite(ConcatOp concatOp,
856 PatternRewriter &rewriter) const override {
857 int64_t dim = concatOp.getDim();
858 RankedTensorType inferredResultType =
859 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
860
861 // The result type should be at least as static as inferred result type.
862 if (preservesStaticInformation(inferredResultType,
863 concatOp.getResultType())) {
864 return failure();
865 }
866
867 auto newConcatOp =
868 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
869 concatOp->getOperands());
870 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
871 newConcatOp);
872
873 return success();
874 }
875};
876} // namespace
877
878void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
879 MLIRContext *context) {
880 results
881 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
882 context);
883}
884
885//===----------------------------------------------------------------------===//
886// DimOp
887//===----------------------------------------------------------------------===//
888
889void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
890 setNameFn(getResult(), "dim");
891}
892
893void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
894 int64_t index) {
895 auto loc = result.location;
896 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
897 build(builder, result, source, indexValue);
898}
899
900std::optional<int64_t> DimOp::getConstantIndex() {
902}
903
904Speculation::Speculatability DimOp::getSpeculatability() {
905 auto constantIndex = getConstantIndex();
906 if (!constantIndex)
908
909 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
910 if (!rankedSourceType)
912
913 if (rankedSourceType.getRank() <= constantIndex)
915
917}
918
919void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
920 SetIntLatticeFn setResultRange) {
921 setResultRange(getResult(),
922 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
923}
924
925OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
926 // All forms of folding require a known index.
927 std::optional<int64_t> index = getConstantIndex();
928 if (!index)
929 return {};
930
931 // Folding for unranked types (UnrankedTensorType) is not supported.
932 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
933 if (!tensorType)
934 return {};
935
936 // Out of bound indices produce undefined behavior but are still valid IR.
937 // Don't choke on them.
938 int64_t indexVal = index.value();
939 if (indexVal < 0 || indexVal >= tensorType.getRank())
940 return {};
941
942 // Fold if the shape extent along the given index is known.
943 if (!tensorType.isDynamicDim(indexVal)) {
944 Builder builder(getContext());
945 return builder.getIndexAttr(tensorType.getShape()[indexVal]);
946 }
947
948 Operation *definingOp = getSource().getDefiningOp();
949
950 // Fold dim to the operand of tensor.generate.
951 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
952 auto resultType =
953 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
954 // The case where the type encodes the size of the dimension is handled
955 // above.
956 assert(ShapedType::isDynamic(resultType.getShape()[indexVal]));
957
958 // Find the operand of the fromElements that corresponds to this index.
959 auto dynExtents = fromElements.getDynamicExtents().begin();
960 for (auto dim : resultType.getShape().take_front(indexVal))
961 if (ShapedType::isDynamic(dim))
962 dynExtents++;
963
964 return Value{*dynExtents};
965 }
966
967 // The size at the given index is now known to be a dynamic size.
968 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
969 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
970 // `resolve-shaped-type-result-dims` pass.
971 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
972 sliceOp.isDynamicSize(indexVal)) {
973 return {sliceOp.getDynamicSize(indexVal)};
974 }
975 }
976
977 // dim(cast) -> dim
978 if (succeeded(foldTensorCast(*this)))
979 return getResult();
980
981 return {};
982}
983
984namespace {
985/// Fold dim of a cast into the dim of the source of the tensor cast.
986struct DimOfCastOp : public OpRewritePattern<DimOp> {
987 using OpRewritePattern<DimOp>::OpRewritePattern;
988
989 LogicalResult matchAndRewrite(DimOp dimOp,
990 PatternRewriter &rewriter) const override {
991 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
992 if (!castOp)
993 return failure();
994 Value newSource = castOp.getOperand();
995 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
996 return success();
997 }
998};
999
1000/// Fold dim of a destination passing style op into the dim of the corresponding
1001/// init.
1002struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1003 using OpRewritePattern<DimOp>::OpRewritePattern;
1004
1005 LogicalResult matchAndRewrite(DimOp dimOp,
1006 PatternRewriter &rewriter) const override {
1007 auto source = dimOp.getSource();
1008 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1009 if (!destOp)
1010 return failure();
1011
1012 auto resultIndex = cast<OpResult>(source).getResultNumber();
1013 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1014
1015 rewriter.modifyOpInPlace(
1016 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1017 return success();
1018 }
1019};
1020
1021/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
1022/// operand.
1023struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1024 using OpRewritePattern<DimOp>::OpRewritePattern;
1025
1026 LogicalResult matchAndRewrite(DimOp dim,
1027 PatternRewriter &rewriter) const override {
1028 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1029
1030 if (!reshape)
1031 return failure();
1032
1033 // Since tensors are immutable we don't need to worry about where to place
1034 // the extract call
1035 rewriter.setInsertionPointAfter(dim);
1036 Location loc = dim.getLoc();
1037 Value extract =
1038 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1039 if (extract.getType() != dim.getType())
1040 extract =
1041 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1042 rewriter.replaceOp(dim, extract);
1043 return success();
1044 }
1045};
1046} // namespace
1047
1048void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1049 MLIRContext *context) {
1050 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1051}
1052
1053//===----------------------------------------------------------------------===//
1054// EmptyOp
1055//===----------------------------------------------------------------------===//
1056
1057void EmptyOp::build(OpBuilder &builder, OperationState &result,
1058 ArrayRef<int64_t> staticShape, Type elementType,
1059 Attribute encoding) {
1060 assert(none_of(staticShape, ShapedType::isDynamic) &&
1061 "expected only static sizes");
1062 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1063}
1064
1065void EmptyOp::build(OpBuilder &builder, OperationState &result,
1066 ArrayRef<int64_t> staticShape, Type elementType,
1067 ValueRange dynamicSizes, Attribute encoding) {
1068 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1069 build(builder, result, tensorType, dynamicSizes);
1070}
1071
1072void EmptyOp::build(OpBuilder &builder, OperationState &result,
1073 ArrayRef<OpFoldResult> sizes, Type elementType,
1074 Attribute encoding) {
1075 SmallVector<int64_t> staticShape;
1076 SmallVector<Value> dynamicSizes;
1077 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1078 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1079}
1080
1081LogicalResult EmptyOp::verify() {
1082 return verifyDynamicDimensionCount(getOperation(), getType(),
1083 getDynamicSizes());
1084}
1085
1086LogicalResult
1087EmptyOp::reifyResultShapes(OpBuilder &builder,
1088 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1089 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1090 unsigned ctr = 0;
1091 for (int64_t i = 0; i < getType().getRank(); ++i) {
1092 if (getType().isDynamicDim(i)) {
1093 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1094 } else {
1095 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1096 }
1097 }
1098 return success();
1099}
1100
1101Value EmptyOp::getDynamicSize(unsigned idx) {
1102 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1103 unsigned ctr = 0;
1104 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1105 if (getType().isDynamicDim(i))
1106 ++ctr;
1107 return getDynamicSizes()[ctr];
1108}
1109
1110SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1111 SmallVector<OpFoldResult> result;
1112 unsigned ctr = 0;
1113 Builder b(getContext());
1114 for (int64_t dim : getType().getShape()) {
1115 if (ShapedType::isDynamic(dim)) {
1116 result.push_back(getDynamicSizes()[ctr++]);
1117 } else {
1118 result.push_back(b.getIndexAttr(dim));
1119 }
1120 }
1121 return result;
1122}
1123
1124namespace {
1125/// Change the type of the result of a `tensor.empty` by making the result
1126/// type statically sized along dimensions that in the original operation were
1127/// defined as dynamic, but the size was defined using a `constant` op. For
1128/// example
1129///
1130/// %c5 = arith.constant 5: index
1131/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1132///
1133/// to
1134///
1135/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1136struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1137 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1138
1139 LogicalResult matchAndRewrite(EmptyOp op,
1140 PatternRewriter &rewriter) const override {
1141 SmallVector<Value> foldedDynamicSizes;
1142 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1143 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1144
1145 // Stop here if no dynamic size was promoted to static.
1146 if (foldedTensorType == op.getType())
1147 return failure();
1148
1149 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1150 foldedDynamicSizes);
1151 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1152 return success();
1153 }
1154};
1155
1156struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1157 using OpRewritePattern<DimOp>::OpRewritePattern;
1158
1159 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1160 PatternRewriter &rewriter) const override {
1161 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1162 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1163 if (!emptyTensorOp || !maybeConstantIndex)
1164 return failure();
1165 auto emptyTensorType = emptyTensorOp.getType();
1166 if (*maybeConstantIndex < 0 ||
1167 *maybeConstantIndex >= emptyTensorType.getRank() ||
1168 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1169 return failure();
1170 rewriter.replaceOp(dimOp,
1171 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1172 return success();
1173 }
1174};
1175
1176/// Canonicalize
1177///
1178/// ```mlir
1179/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1180/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1181/// ```
1182///
1183/// into
1184///
1185/// ```mlir
1186/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1187/// ```
1188///
1189/// This assumes the input program is correct in terms of its shape. So it is
1190/// safe to assume that `%d0` is in fact 4.
1191struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1192 using OpRewritePattern<CastOp>::OpRewritePattern;
1193
1194 LogicalResult matchAndRewrite(CastOp castOp,
1195 PatternRewriter &rewriter) const override {
1196 if (!canFoldIntoProducerOp(castOp))
1197 return failure();
1198 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1199 if (!producer)
1200 return failure();
1201
1202 auto resultType =
1203 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1204 ArrayRef<int64_t> resultShape = resultType.getShape();
1205 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1206 SmallVector<OpFoldResult> newMixedSizes;
1207 newMixedSizes.reserve(currMixedSizes.size());
1208 assert(resultShape.size() == currMixedSizes.size() &&
1209 "mismatch in result shape and sizes of empty op");
1210 for (auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
1211 // Case 1: The empty tensor dim is static. Check that the tensor cast
1212 // result dim matches.
1213 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1214 if (ShapedType::isDynamic(newDim) ||
1215 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1216 // Something is off, the cast result shape cannot be more dynamic
1217 // than the empty tensor result shape (enforced by
1218 // `canFoldIntoProducer`). Abort for now.
1219 return rewriter.notifyMatchFailure(
1220 producer, "mismatch in static value of shape of empty tensor "
1221 "result and cast result");
1222 }
1223 newMixedSizes.push_back(attr);
1224 continue;
1225 }
1226
1227 // Case 2 : The tensor cast shape is static, but empty tensor result
1228 // shape is dynamic.
1229 if (ShapedType::isStatic(newDim)) {
1230 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1231 continue;
1232 }
1233
1234 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1235 // shape is dynamic. Use the dynamic value from the empty tensor op.
1236 newMixedSizes.push_back(currDim);
1237 }
1238
1239 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1240 resultType.getElementType(),
1241 resultType.getEncoding());
1242 return success();
1243 }
1244};
1245
1246} // namespace
1247
1248void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1249 MLIRContext *context) {
1250 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1251 ReplaceEmptyTensorStaticShapeDims>(context);
1252}
1253
1254//===----------------------------------------------------------------------===//
1255// ExtractOp
1256//===----------------------------------------------------------------------===//
1257
1258namespace {
1259
1260/// Canonicalizes the pattern of the form
1261///
1262/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
1263/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1264///
1265/// to
1266///
1267/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1268struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1269 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1270
1271 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1272 PatternRewriter &rewriter) const final {
1273 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1274 if (!tensorCast)
1275 return failure();
1276 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1277 return failure();
1278 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1279 extract, tensorCast.getSource(), extract.getIndices());
1280 return success();
1281 }
1282};
1283
1284/// Canonicalizes the pattern of the form
1285///
1286/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1287/// tensor<12xf64>
1288/// %extracted_element = tensor.extract %val[%c10] :
1289/// tensor<12xf64>
1290///
1291/// to
1292///
1293/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1294struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1295 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1296
1297 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1298 PatternRewriter &rewriter) const final {
1299 auto collapseOp =
1300 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1301 if (!collapseOp)
1302 return failure();
1303 if (!collapseOp.getSrcType().hasStaticShape())
1304 return failure();
1305
1306 auto sourceSizes = collapseOp.getSrcType().getShape();
1307
1308 SmallVector<Value> indices(extractOp.getIndices().begin(),
1309 extractOp.getIndices().end());
1310 SmallVector<Value> sourceIndices;
1311 for (auto [index, group] :
1312 llvm::zip(indices, collapseOp.getReassociationIndices())) {
1313 assert(!group.empty() && "association indices groups cannot be empty");
1314 auto groupSize = group.size();
1315
1316 if (groupSize == 1) {
1317 sourceIndices.push_back(index);
1318 continue;
1319 }
1320
1321 SmallVector<int64_t> basis =
1322 llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1323 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1324 rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1325 llvm::append_range(sourceIndices, delinearize.getResults());
1326 }
1327 if (collapseOp.getReassociationIndices().empty()) {
1328 auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1329 int64_t srcRank =
1330 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1331 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1332 rewriter, extractOp.getLoc(), zeroAffineMap,
1333 ArrayRef<OpFoldResult>{});
1334 for (int64_t i = 0; i < srcRank; i++) {
1335 sourceIndices.push_back(
1336 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1337 }
1338 }
1339
1340 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1341 extractOp, collapseOp.getSrc(), sourceIndices);
1342 return success();
1343 }
1344};
1345
1346} // namespace
1347
1348void ExtractOp::getAsmResultNames(
1349 function_ref<void(Value, StringRef)> setNameFn) {
1350 setNameFn(getResult(), "extracted");
1351}
1352
1353LogicalResult ExtractOp::verify() {
1354 // Verify the # indices match if we have a ranked type.
1355 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1356 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1357 return emitOpError("incorrect number of indices for extract_element");
1358 return success();
1359}
1360
1361/// If we have an ExtractOp consuming an InsertOp with the same
1362/// indices, we can return the InsertOp's scalar directly.
1363// TODO: This only checks the immediate producer; extend to go up the
1364// insert/extract chain if the slices are disjoint.
1365static Value foldExtractAfterInsert(ExtractOp extractOp) {
1366 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1367
1368 auto isSame = [](Value a, Value b) {
1370 };
1371 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1372 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1373 return insertOp.getScalar();
1374
1375 return {};
1376}
1377
1378OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1379 if (Attribute tensor = adaptor.getTensor()) {
1380 // If this is a splat elements attribute, simply return the value.
1381 // All of the elements of a splat attribute are the same.
1382 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1383 return splatTensor.getSplatValue<Attribute>();
1384
1385 // If this is a dense resource elements attribute, return.
1386 if (isa<DenseResourceElementsAttr>(tensor))
1387 return {};
1388 }
1389
1390 // Collect the constant indices into the tensor.
1391 SmallVector<uint64_t, 8> indices;
1392 for (Attribute indice : adaptor.getIndices()) {
1393 if (!indice || !llvm::isa<IntegerAttr>(indice))
1394 return {};
1395 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1396 }
1397
1398 // Fold extract(from_elements(...)).
1399 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1400 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1401 auto rank = tensorType.getRank();
1402 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1403 "rank mismatch");
1404 int flatIndex = 0;
1405 int stride = 1;
1406 for (int i = rank - 1; i >= 0; --i) {
1407 flatIndex += indices[i] * stride;
1408 stride *= tensorType.getDimSize(i);
1409 }
1410 // Prevent out of bounds accesses. This can happen in invalid code that
1411 // will never execute.
1412 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1413 flatIndex < 0)
1414 return {};
1415 return fromElementsOp.getElements()[flatIndex];
1416 }
1417
1418 // If this is an elements attribute, query the value at the given indices.
1419 if (Attribute tensor = adaptor.getTensor()) {
1420 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1421 if (elementsAttr && elementsAttr.isValidIndex(indices))
1422 return elementsAttr.getValues<Attribute>()[indices];
1423 }
1424
1425 if (Value result = foldExtractAfterInsert(*this))
1426 return result;
1427
1428 return {};
1429}
1430
1431void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1432 MLIRContext *context) {
1433 results.add<ExtractFromTensorCast>(context);
1434}
1435
1437 RewritePatternSet &patterns) {
1438 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1439}
1440
1441//===----------------------------------------------------------------------===//
1442// FromElementsOp
1443//===----------------------------------------------------------------------===//
1444
1445void FromElementsOp::getAsmResultNames(
1446 function_ref<void(Value, StringRef)> setNameFn) {
1447 setNameFn(getResult(), "from_elements");
1448}
1449
1450void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1451 ValueRange elements) {
1452 assert(!elements.empty() && "expected at least one element");
1453 Type resultType = RankedTensorType::get(
1454 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1455 build(builder, result, resultType, elements);
1456}
1457
1458OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1459 // DenseElementsAttr::get requires StringAttr for element types that are not
1460 // integer, index, float, or complex (e.g. vector types), but folded constants
1461 // won't be StringAttr instances. Only fold for element types directly
1462 // supported by DenseElementsAttr.
1463 Type eltType = getType().getElementType();
1464 if (!eltType.isIntOrIndexOrFloat() && !isa<ComplexType>(eltType))
1465 return {};
1466 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1467 return DenseElementsAttr::get(getType(), adaptor.getElements());
1468 return {};
1469}
1470
1471namespace {
1472
1473// Pushes the index_casts that occur before extractions to after the extract.
1474// This minimizes type conversion in some cases and enables the extract
1475// canonicalizer. This changes:
1476//
1477// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1478// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1479//
1480// to the following:
1481//
1482// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1483// %cast = arith.index_cast %extract : i32 to index
1484//
1485// to just %element.
1486//
1487// Consider expanding this to a template and handle all tensor cast
1488// operations.
1489struct ExtractElementFromIndexCast
1490 : public OpRewritePattern<tensor::ExtractOp> {
1491 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1492
1493 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1494 PatternRewriter &rewriter) const final {
1495 Location loc = extract.getLoc();
1496 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1497 if (!indexCast)
1498 return failure();
1499
1500 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1501
1502 auto newExtract = tensor::ExtractOp::create(
1503 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1504
1505 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1506 newExtract);
1507
1508 return success();
1509 }
1510};
1511
1512} // namespace
1513
1514void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1515 MLIRContext *context) {
1516 results.add<ExtractElementFromIndexCast>(context);
1517}
1518
1519//===----------------------------------------------------------------------===//
1520// GatherOp
1521//===----------------------------------------------------------------------===//
1522
1523void GatherOp::getAsmResultNames(
1524 function_ref<void(Value, StringRef)> setNameFn) {
1525 setNameFn(getResult(), "gather");
1526}
1527
1528/// Return the inferred result type for a gatherOp where:
1529/// - sourceType is the type of the source tensor gathered from
1530/// - indicesType is the type of the indices used to gather
1531/// - gatherDims are the dims along which the gather occurs.
1532/// Return a full rank or ranked-reduced variant of the type depending on
1533/// the value of rankReduced.
1534///
1535/// The leading dimensions of the index tensor give the result tensor its
1536/// leading dimensions.
1537/// The trailing dimensions of the result tensor are obtained from the source
1538/// tensor by setting the dimensions specified in gather_dims to `1` (if
1539/// rankedReduced is false), or skipping them (otherwise).
1540RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1541 RankedTensorType indicesType,
1542 ArrayRef<int64_t> gatherDims,
1543 bool rankReduced) {
1544 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1545 resultShape.reserve(resultShape.size() + sourceType.getRank());
1546 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1547 if (llvm::binary_search(gatherDims, idx)) {
1548 if (!rankReduced)
1549 resultShape.push_back(1);
1550 continue;
1551 }
1552 resultShape.push_back(sourceType.getDimSize(idx));
1553 }
1554 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1555}
1556
1557static LogicalResult
1560 StringRef gatherOrScatter, StringRef sourceOrDest) {
1561 if (dims.empty())
1562 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1563
1564 int64_t numGatherDims = dims.size();
1565 if (numGatherDims > rank)
1566 return op->emitOpError(gatherOrScatter)
1567 << "_dims overflow " << sourceOrDest << " rank";
1568 if (indices.empty() || indices.back() != numGatherDims)
1569 return op->emitOpError(gatherOrScatter)
1570 << "_dims length must match the size of last dimension of indices";
1571 for (int64_t val : dims) {
1572 if (val < 0)
1573 return op->emitOpError(gatherOrScatter)
1574 << "_dims value must be non-negative";
1575 if (val >= rank)
1576 return op->emitOpError(gatherOrScatter)
1577 << "_dims value must be smaller than " << sourceOrDest << " rank";
1578 }
1579 for (int64_t i = 1; i < numGatherDims; ++i) {
1580 if (dims[i - 1] >= dims[i])
1581 return op->emitOpError(gatherOrScatter)
1582 << "_dims values must be strictly increasing";
1583 }
1584 return success();
1585}
1586
1587LogicalResult GatherOp::verify() {
1588 int64_t sourceRank = getSourceType().getRank();
1589 ArrayRef<int64_t> gatherDims = getGatherDims();
1590 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1591 getIndicesType().getShape(), sourceRank,
1592 "gather", "source")))
1593 return failure();
1594
1595 RankedTensorType expectedResultType = GatherOp::inferResultType(
1596 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1597 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1598 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1599 if (getResultType() != expectedResultType &&
1600 getResultType() != expectedRankReducedResultType) {
1601 return emitOpError("result type "
1602 "mismatch: "
1603 "expected ")
1604 << expectedResultType << " or its rank-reduced variant "
1605 << expectedRankReducedResultType << " (got: " << getResultType()
1606 << ")";
1607 }
1608
1609 return success();
1610}
1611
1612OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1613 if (OpFoldResult reshapedSource = reshapeConstantSource(
1614 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1615 getResult().getType()))
1616 return reshapedSource;
1617 return {};
1618}
1619
1620//===----------------------------------------------------------------------===//
1621// InsertOp
1622//===----------------------------------------------------------------------===//
1623
1624void InsertOp::getAsmResultNames(
1625 function_ref<void(Value, StringRef)> setNameFn) {
1626 setNameFn(getResult(), "inserted");
1627}
1628
1629LogicalResult InsertOp::verify() {
1630 // Verify the # indices match if we have a ranked type.
1631 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1632 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1633 return emitOpError("incorrect number of indices");
1634 return success();
1635}
1636
1637OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1638 Attribute scalar = adaptor.getScalar();
1639 Attribute dest = adaptor.getDest();
1640 if (scalar && dest)
1641 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1642 if (scalar == splatDest.getSplatValue<Attribute>())
1643 return dest;
1644 return {};
1645}
1646
1647//===----------------------------------------------------------------------===//
1648// GenerateOp
1649//===----------------------------------------------------------------------===//
1650
1651void GenerateOp::getAsmResultNames(
1652 function_ref<void(Value, StringRef)> setNameFn) {
1653 setNameFn(getResult(), "generated");
1654}
1655
1656LogicalResult GenerateOp::reifyResultShapes(
1657 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1658 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1659 int idx = 0;
1660 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1661 if (getType().isDynamicDim(dim)) {
1662 reifiedReturnShapes[0][dim] = getOperand(idx++);
1663 } else {
1664 reifiedReturnShapes[0][dim] =
1665 builder.getIndexAttr(getType().getDimSize(dim));
1666 }
1667 }
1668 return success();
1669}
1670
1671LogicalResult GenerateOp::verify() {
1672 // Ensure that the tensor type has as many dynamic dimensions as are
1673 // specified by the operands.
1674 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1675 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
1676 getOperands())))
1677 return failure();
1678 return success();
1679}
1680
1681LogicalResult GenerateOp::verifyRegions() {
1682 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1683 // Ensure that region arguments span the index space.
1684 if (!llvm::all_of(getBody().getArgumentTypes(),
1685 [](Type ty) { return ty.isIndex(); }))
1686 return emitError("all body arguments must be index");
1687 if (getBody().getNumArguments() != resultTy.getRank())
1688 return emitError("must have one body argument per input dimension");
1689
1690 // Ensure that the region yields an element of the right type.
1691 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1692
1693 if (yieldOp.getValue().getType() != resultTy.getElementType())
1694 return emitOpError(
1695 "body must be terminated with a `yield` operation of the tensor "
1696 "element type");
1697
1698 return success();
1699}
1700
1701void GenerateOp::build(
1702 OpBuilder &b, OperationState &result, Type resultTy,
1703 ValueRange dynamicExtents,
1704 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1705 build(b, result, resultTy, dynamicExtents);
1706
1707 // Build and populate body.
1708 OpBuilder::InsertionGuard guard(b);
1709 Region *bodyRegion = result.regions.front().get();
1710 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1711 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1712 SmallVector<Location, 2> argumentLocs(rank, result.location);
1713 Block *bodyBlock =
1714 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1715 bodyBuilder(b, result.location, bodyBlock->getArguments());
1716}
1717
1718namespace {
1719
1720/// Canonicalizes tensor.generate operations with a constant
1721/// operand into the equivalent operation with the operand expressed in the
1722/// result type, instead. We also insert a type cast to make sure that the
1723/// resulting IR is still well-typed.
1724struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1725 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1726
1727 LogicalResult matchAndRewrite(GenerateOp generateOp,
1728 PatternRewriter &rewriter) const final {
1729 SmallVector<Value> foldedDynamicSizes;
1730 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1731 generateOp.getType(), generateOp.getDynamicExtents(),
1732 foldedDynamicSizes);
1733
1734 // Stop here if no dynamic size was promoted to static.
1735 if (foldedTensorType == generateOp.getType())
1736 return failure();
1737
1738 auto loc = generateOp.getLoc();
1739 auto newOp =
1740 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1741 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1742 newOp.getBody().begin());
1743 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1744 generateOp.getType(), newOp);
1745 return success();
1746 }
1747};
1748
1749/// Canonicalizes the pattern of the form
1750///
1751/// %tensor = tensor.generate %x {
1752/// ^bb0(%arg0: index):
1753/// <computation>
1754/// yield %1 : index
1755/// } : tensor<?xindex>
1756/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1757///
1758/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1759/// tensor.generate operation has no side-effects.
1760struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1761 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1762
1763 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1764 PatternRewriter &rewriter) const final {
1765 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1766 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1767 return failure();
1768
1769 IRMapping mapping;
1770 Block *body = &tensorFromElements.getBody().front();
1771 mapping.map(body->getArguments(), extract.getIndices());
1772 for (auto &op : body->without_terminator())
1773 rewriter.clone(op, mapping);
1774
1775 auto yield = cast<YieldOp>(body->getTerminator());
1776
1777 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1778 return success();
1779 }
1780};
1781
1782} // namespace
1783
1784void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1785 MLIRContext *context) {
1786 // TODO: Move extract pattern to tensor::ExtractOp.
1787 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1788}
1789
1790//===----------------------------------------------------------------------===//
1791// RankOp
1792//===----------------------------------------------------------------------===//
1793
1794void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1795 setNameFn(getResult(), "rank");
1796}
1797
1798OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1799 // Constant fold rank when the rank of the operand is known.
1800 auto type = getOperand().getType();
1801 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1802 if (shapedType && shapedType.hasRank())
1803 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1804 return IntegerAttr();
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// ReshapeOp
1809//===----------------------------------------------------------------------===//
1810
1811void ReshapeOp::getAsmResultNames(
1812 function_ref<void(Value, StringRef)> setNameFn) {
1813 setNameFn(getResult(), "reshape");
1814}
1815
1816static int64_t getNumElements(ShapedType type) {
1817 int64_t numElements = 1;
1818 for (auto dim : type.getShape())
1819 numElements *= dim;
1820 return numElements;
1821}
1822
1823LogicalResult ReshapeOp::verify() {
1824 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1825 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1826
1827 if (operandType.getElementType() != resultType.getElementType())
1828 return emitOpError("element types of source and destination tensor "
1829 "types should be the same");
1830
1831 int64_t shapeSize =
1832 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1833 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1834 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1835
1836 if (resultRankedType) {
1837 if (operandRankedType && resultRankedType.hasStaticShape() &&
1838 operandRankedType.hasStaticShape()) {
1839 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1840 return emitOpError("source and destination tensor should have the "
1841 "same number of elements");
1842 }
1843 if (ShapedType::isDynamic(shapeSize))
1844 return emitOpError("cannot use shape operand with dynamic length to "
1845 "reshape to statically-ranked tensor type");
1846 if (shapeSize != resultRankedType.getRank())
1847 return emitOpError(
1848 "length of shape operand differs from the result's tensor rank");
1849 }
1850 return success();
1851}
1852
1853OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1854 if (OpFoldResult reshapedSource = reshapeConstantSource(
1855 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1856 getResult().getType()))
1857 return reshapedSource;
1858
1859 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1860 // producer's input instead as the original tensor to reshape. This could
1861 // render such producer dead code.
1862 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1863 getSourceMutable().assign(reshapeOpProducer.getSource());
1864 return getResult();
1865 }
1866
1867 auto source = getSource();
1868 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1869 auto resultTy = dyn_cast<RankedTensorType>(getType());
1870 if (!sourceTy || !resultTy || sourceTy != resultTy)
1871 return {};
1872
1873 // If the source and result are both 0D or 1D tensors and have the same type,
1874 // the reshape has no effect, even if the tensor is dynamically shaped.
1875 if (sourceTy.getRank() <= 1)
1876 return source;
1877
1878 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1879 auto elements = fromElements.getElements();
1880 bool dynamicNoop =
1881 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1882 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1883 auto element = elements[id];
1884
1885 if (auto cst = getConstantIntValue(element)) {
1886 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1887 continue;
1888 }
1889
1890 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1891 dynamicNoop &= dimOp.getSource() == source;
1892
1893 auto cst = getConstantIntValue(dimOp.getIndex());
1894 dynamicNoop &=
1895 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1896 continue;
1897 }
1898
1899 dynamicNoop = false;
1900 break;
1901 }
1902
1903 if (dynamicNoop)
1904 return source;
1905 }
1906
1907 return {};
1908}
1909
1910//===----------------------------------------------------------------------===//
1911// Reassociative reshape ops
1912//===----------------------------------------------------------------------===//
1913
1914void CollapseShapeOp::getAsmResultNames(
1915 function_ref<void(Value, StringRef)> setNameFn) {
1916 setNameFn(getResult(), "collapsed");
1917}
1918
1919void ExpandShapeOp::getAsmResultNames(
1920 function_ref<void(Value, StringRef)> setNameFn) {
1921 setNameFn(getResult(), "expanded");
1922}
1923
1924int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1925 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1926 "invalid resultDim");
1927 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1928 if (llvm::is_contained(it.value(), resultDim))
1929 return it.index();
1930 llvm_unreachable("could not find reassociation group");
1931}
1932
1933FailureOr<SmallVector<OpFoldResult>>
1934ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1935 RankedTensorType expandedType,
1936 ArrayRef<ReassociationIndices> reassociation,
1937 ArrayRef<OpFoldResult> inputShape) {
1938 std::optional<SmallVector<OpFoldResult>> outputShape =
1939 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1940 inputShape);
1941 if (!outputShape)
1942 return failure();
1943 return *outputShape;
1944}
1945
1946SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1947 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1948}
1949
1950void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1951 Type resultType, Value src,
1952 ArrayRef<ReassociationIndices> reassociation,
1953 ArrayRef<OpFoldResult> outputShape) {
1954 auto [staticOutputShape, dynamicOutputShape] =
1955 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1956 build(builder, result, cast<RankedTensorType>(resultType), src,
1957 getReassociationIndicesAttribute(builder, reassociation),
1958 dynamicOutputShape, staticOutputShape);
1959}
1960
1961void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1962 Type resultType, Value src,
1963 ArrayRef<ReassociationIndices> reassociation) {
1964 SmallVector<OpFoldResult> inputShape =
1965 getMixedSizes(builder, result.location, src);
1966 auto tensorResultTy = cast<RankedTensorType>(resultType);
1967 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1968 builder, result.location, tensorResultTy, reassociation, inputShape);
1969 SmallVector<OpFoldResult> outputShapeOrEmpty;
1970 if (succeeded(outputShape)) {
1971 outputShapeOrEmpty = *outputShape;
1972 }
1973 build(builder, result, tensorResultTy, src, reassociation,
1974 outputShapeOrEmpty);
1975}
1976
1977SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1978 return getSymbolLessAffineMaps(getReassociationExprs());
1979}
1980SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1982 getReassociationIndices());
1983}
1984
1985SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1986 return getSymbolLessAffineMaps(getReassociationExprs());
1987}
1988SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1990 getReassociationIndices());
1991}
1992
1993RankedTensorType CollapseShapeOp::inferCollapsedType(
1994 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1995 return inferCollapsedType(
1997 type.getContext(), reassociation)));
1998}
1999
2000/// Compute the RankedTensorType obtained by applying `reassociation` to
2001/// `type`.
2002RankedTensorType
2003CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2004 ArrayRef<AffineMap> reassociation) {
2005 auto shape = type.getShape();
2006 SmallVector<int64_t, 4> newShape;
2007 newShape.reserve(reassociation.size());
2008
2009 // Use the fact that reassociation is valid to simplify the logic: only use
2010 // each map's rank.
2011 assert(isReassociationValid(reassociation) && "invalid reassociation");
2012 unsigned currentDim = 0;
2013 for (AffineMap m : reassociation) {
2014 unsigned dim = m.getNumResults();
2015 auto band = shape.slice(currentDim, dim);
2016 int64_t size = 1;
2017 if (llvm::is_contained(band, ShapedType::kDynamic))
2018 size = ShapedType::kDynamic;
2019 else
2020 for (unsigned d = 0; d < dim; ++d)
2021 size *= shape[currentDim + d];
2022 newShape.push_back(size);
2023 currentDim += dim;
2024 }
2025
2026 return RankedTensorType::get(newShape, type.getElementType());
2027}
2028
2029void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2030 ArrayRef<ReassociationIndices> reassociation,
2031 ArrayRef<NamedAttribute> attrs) {
2032 auto srcType = llvm::cast<RankedTensorType>(src.getType());
2033 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2034 auto resultType =
2035 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2036 srcType.getEncoding());
2037 result.addAttribute(getReassociationAttrStrName(),
2038 getReassociationIndicesAttribute(b, reassociation));
2039 build(b, result, resultType, src, attrs);
2040}
2041
2042template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2043 TensorReshapeOp, ExpandShapeOp>::value>
2044static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2045 RankedTensorType expandedType,
2046 RankedTensorType collapsedType) {
2047 if (failed(
2048 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2049 return failure();
2050
2051 // Reshape must preserve the number of elements when statically known.
2052 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
2053 int64_t expandedNumElements = expandedType.getNumElements();
2054 int64_t collapsedNumElements = collapsedType.getNumElements();
2055 if (expandedNumElements != collapsedNumElements) {
2056 return op.emitOpError("number of elements must be preserved: ")
2057 << expandedNumElements << " != " << collapsedNumElements;
2058 }
2059 }
2060
2061 auto maps = op.getReassociationMaps();
2062 RankedTensorType expectedType =
2063 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2064 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2065 return op.emitOpError("expected collapsed type to be ")
2066 << expectedType << ", but got " << collapsedType;
2067 return success();
2068}
2069
2070LogicalResult ExpandShapeOp::verify() {
2071 RankedTensorType srcType = getSrc().getType();
2072 RankedTensorType resultType = getResult().getType();
2073
2074 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2075 return emitOpError("expected number of static shape dims to be equal to "
2076 "the output rank (")
2077 << resultType.getRank() << ") but found "
2078 << getStaticOutputShape().size() << " inputs instead";
2079
2080 if ((int64_t)getOutputShape().size() !=
2081 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2082 return emitOpError("mismatch in dynamic dims in output_shape and "
2083 "static_output_shape: static_output_shape has ")
2084 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2085 << " dynamic dims while output_shape has " << getOutputShape().size()
2086 << " values";
2087
2088 // Verify that the number of dynamic dims in output_shape matches the number
2089 // of dynamic dims in the result type.
2090 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
2091 getOutputShape())))
2092 return failure();
2093
2094 // Verify if provided output shapes are in agreement with output type.
2095 DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2096 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2097 for (auto [pos, shape] : llvm::enumerate(resShape))
2098 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos])
2099 return emitOpError("invalid output shape provided at pos ") << pos;
2100
2101 return verifyTensorReshapeOp(*this, resultType, srcType);
2102}
2103
2104LogicalResult CollapseShapeOp::verify() {
2105 CollapseShapeOp op = *this;
2106 if (llvm::any_of(op.getReassociationIndices(),
2107 [](ReassociationIndices group) { return group.empty(); })) {
2108 return op.emitOpError("reassociation indices must not be empty");
2109 }
2110 RankedTensorType srcType = op.getSrc().getType();
2111 RankedTensorType resultType = op.getResult().getType();
2112
2113 return verifyTensorReshapeOp(op, srcType, resultType);
2114}
2115
2116namespace {
2117/// Reshape of a splat constant can be replaced with a constant of the result
2118/// type.
2119template <typename TensorReshapeOp>
2120struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2121 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2122 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2123 PatternRewriter &rewriter) const override {
2124 DenseElementsAttr attr;
2125 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2126 return failure();
2127 if (!attr || !attr.isSplat())
2128 return failure();
2129 // DenseElementsAttr requires a static shape; skip folding for dynamic
2130 // result types.
2131 if (!reshapeOp.getResultType().hasStaticShape())
2132 return failure();
2133 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2134 reshapeOp.getResultType(), attr.getRawData());
2135 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2136 return success();
2137 }
2138};
2139
2140// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2141template <typename TensorReshapeOp>
2142class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2143public:
2144 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2145
2146 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2147 PatternRewriter &rewriter) const override {
2148 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2149 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2150 return failure();
2151
2152 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2153 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2154 return success();
2155 }
2156};
2157
2158/// Reshape of a FromElements can be replaced with a FromElements of the
2159/// result type
2160template <typename TensorReshapeOp>
2161struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2162 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2163 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2164 PatternRewriter &rewriter) const override {
2165 auto fromElements =
2166 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2167 if (!fromElements)
2168 return failure();
2169
2170 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2171
2172 if (!shapedTy.hasStaticShape())
2173 return failure();
2174
2175 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2176 fromElements.getElements());
2177 return success();
2178 }
2179};
2180
2181// Fold CastOp into CollapseShapeOp when adding static information.
2182struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2183 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2184
2185 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2186 PatternRewriter &rewriter) const override {
2187 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2188 if (!tensor::canFoldIntoConsumerOp(castOp))
2189 return failure();
2190
2191 RankedTensorType srcType =
2192 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2193 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2194 srcType, collapseShapeOp.getReassociationMaps());
2195
2196 if (newResultType == collapseShapeOp.getResultType()) {
2197 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2198 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2199 });
2200 } else {
2201 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2202 newResultType, castOp.getSource(),
2203 collapseShapeOp.getReassociation());
2204 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2205 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2206 }
2207 return success();
2208 }
2209};
2210
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
///
/// The replacement is: cast(expand src) -> new expand_shape with the more
/// static shape -> cast back to the original result type. Result dims whose
/// corresponding source dim (in the cast's *source* type) is dynamic are never
/// made static. NOTE(review): the new input/output types are built from shape
/// and element type only, so any encoding is not carried over.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // Shape of the value feeding the cast (the more static side).
    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    // Start from the current result shape; dynamic entries may become static.
    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    // output_shape operands that stay dynamic in the new op.
    SmallVector<Value> dynamicOutputShape;
    // Walks the output_shape operands; advanced once per dynamic result dim,
    // in the order the reassociation groups visit them below.
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires at
        // least one of the expanded dimensions to be dynamic if the input is
        // dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        // A constant output_shape operand becomes a static result dim.
        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output: each source dim is the
    // product of its group's result dims, or dynamic if any of them is.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    // Materialize: cast input to the new input type, expand with the refined
    // shape, then cast back to the original result type for existing users.
    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
2290} // namespace
2291
2292void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2293 MLIRContext *context) {
2294 results.add<
2295 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2296 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2297 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2298 FoldReshapeWithSplat<ExpandShapeOp>,
2299 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2300}
2301
2302void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2303 MLIRContext *context) {
2304 results.add<
2305 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2306 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2307 tensor::DimOp, RankedTensorType>,
2308 FoldReshapeWithConstant<CollapseShapeOp>,
2309 FoldReshapeWithSplat<CollapseShapeOp>,
2310 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2311 context);
2312}
2313
2314OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2316 adaptor.getOperands());
2317}
2318
2319OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2321 adaptor.getOperands());
2322}
2323
2324//===----------------------------------------------------------------------===//
2325// ExtractSliceOp
2326//===----------------------------------------------------------------------===//
2327
2328void ExtractSliceOp::getAsmResultNames(
2329 function_ref<void(Value, StringRef)> setNameFn) {
2330 setNameFn(getResult(), "extracted_slice");
2331}
2332
2333/// An extract_slice result type can be inferred, when it is not
2334/// rank-reduced, from the source type and the static representation of
2335/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2336RankedTensorType
2337ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2338 ArrayRef<int64_t> staticSizes) {
2339 // An extract_slice op may specify only a leading subset of offset/sizes/
2340 // strides in which case we complete with offset=0, sizes from memref type
2341 // and strides=1.
2342 assert(static_cast<int64_t>(staticSizes.size()) ==
2343 sourceTensorType.getRank() &&
2344 "unexpected staticSizes not equal to rank of source");
2345 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2346 sourceTensorType.getEncoding());
2347}
2348
2349// TODO: This uses neither offsets nor strides!
2350RankedTensorType
2351ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2352 ArrayRef<OpFoldResult> sizes) {
2353 SmallVector<int64_t> staticSizes;
2354 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2355
2356 assert(static_cast<int64_t>(staticSizes.size()) ==
2357 sourceTensorType.getRank() &&
2358 "unexpected staticSizes not equal to rank of source");
2359 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2360 sourceTensorType.getEncoding());
2361}
2362
2363/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2364/// number of sizes), drop as many size 1 as needed to produce an inferred
2365/// type with the desired rank.
2366///
2367/// Note that there may be multiple ways to compute this rank-reduced type:
2368/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2369///
2370/// To disambiguate, this function always drops the first 1 sizes occurrences.
2371RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2372 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2373 ArrayRef<int64_t> sizes) {
2374 // Type inferred in the absence of rank-reducing behavior.
2375 auto inferredType = llvm::cast<RankedTensorType>(
2376 inferResultType(sourceRankedTensorType, sizes));
2377 int rankDiff = inferredType.getRank() - desiredResultRank;
2378 if (rankDiff > 0) {
2379 auto shape = inferredType.getShape();
2380 llvm::SmallBitVector dimsToProject =
2381 getPositionsOfShapeOne(rankDiff, shape);
2382 SmallVector<int64_t> projectedShape;
2383 // Best effort rank-reducing: drop 1s in order.
2384 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2385 if (!dimsToProject.test(pos))
2386 projectedShape.push_back(shape[pos]);
2387 inferredType =
2388 RankedTensorType::get(projectedShape, inferredType.getElementType());
2389 }
2390 return inferredType;
2391}
2392
2393RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2394 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2395 ArrayRef<OpFoldResult> sizes) {
2396 SmallVector<int64_t> staticSizes;
2397 SmallVector<Value> dynamicSizes;
2398 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2399 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2400 desiredResultRank, sourceRankedTensorType, staticSizes);
2401}
2402
2403/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2404/// result type. If the type passed is nullptr, it is inferred.
2405void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2406 RankedTensorType resultType, Value source,
2407 ArrayRef<OpFoldResult> offsets,
2408 ArrayRef<OpFoldResult> sizes,
2409 ArrayRef<OpFoldResult> strides,
2410 ArrayRef<NamedAttribute> attrs) {
2411 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2412 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2413 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2414 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2415 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2416 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2417 // Structuring implementation this way avoids duplication between builders.
2418 if (!resultType) {
2419 resultType = llvm::cast<RankedTensorType>(
2420 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2421 }
2422 result.addAttributes(attrs);
2423 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2424 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2425 b.getDenseI64ArrayAttr(staticSizes),
2426 b.getDenseI64ArrayAttr(staticStrides));
2427}
2428
2429/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2430/// result type.
2431void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2432 ArrayRef<OpFoldResult> offsets,
2433 ArrayRef<OpFoldResult> sizes,
2434 ArrayRef<OpFoldResult> strides,
2435 ArrayRef<NamedAttribute> attrs) {
2436 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2437}
2438
2439/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2440/// a Range vector.
2441void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2442 ArrayRef<Range> ranges,
2443 ArrayRef<NamedAttribute> attrs) {
2444 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2445 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2446}
2447
2448/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2449/// the type passed is nullptr, it is inferred.
2450void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2451 RankedTensorType resultType, Value source,
2452 ValueRange offsets, ValueRange sizes,
2453 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2454 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2455 offsets, [](Value v) -> OpFoldResult { return v; });
2456 SmallVector<OpFoldResult> sizeValues =
2457 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2458 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2459 strides, [](Value v) -> OpFoldResult { return v; });
2460 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2461}
2462
2463/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2464void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2465 ValueRange offsets, ValueRange sizes,
2466 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2467 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2468}
2469
2471 Operation *op,
2472 RankedTensorType expectedType) {
2473 switch (result) {
2475 return success();
2477 return op->emitError("expected rank to be smaller or equal to ")
2478 << "the other rank. ";
2480 return op->emitError("expected type to be ")
2481 << expectedType << " or a rank-reduced version. (size mismatch) ";
2483 return op->emitError("expected element type to be ")
2484 << expectedType.getElementType();
2485 default:
2486 llvm_unreachable("unexpected extract_slice op verification result");
2487 }
2488}
2489
2490/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2491/// result type, offsets set to 0 and strides set to 1.
2492void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2493 RankedTensorType resultType, Value source,
2494 ArrayRef<OpFoldResult> sizes,
2495 ArrayRef<NamedAttribute> attrs) {
2496 Attribute zeroIdxAttr = b.getIndexAttr(0);
2497 Attribute oneIdxAttr = b.getIndexAttr(1);
2498 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2499 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2500 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2501}
2502
2503/// Verifier for ExtractSliceOp.
2504LogicalResult ExtractSliceOp::verify() {
2505 RankedTensorType sourceType = getSourceType();
2506
2507 // Verify result type against inferred type.
2508 RankedTensorType expectedType =
2509 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2512 return produceSliceErrorMsg(result, *this, expectedType);
2513
2514 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2515 // to the source tensor.
2516 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2517 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2518 getStaticStrides(), /*generateErrorMessage=*/true);
2519 if (!boundsResult.isValid)
2520 return getOperation()->emitError(boundsResult.errorMessage);
2521
2522 return success();
2523}
2524
2525llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2526 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2527}
2528
2529FailureOr<Value>
2530ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2531 ArrayRef<int64_t> desiredShape) {
2532 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2533 assert(sourceTensorType && "not a ranked tensor type");
2534 auto sourceShape = sourceTensorType.getShape();
2535 if (sourceShape.equals(desiredShape))
2536 return value;
2537 auto maybeRankReductionMask =
2538 mlir::computeRankReductionMask(sourceShape, desiredShape);
2539 if (!maybeRankReductionMask)
2540 return failure();
2542 b, loc, value,
2543 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2544}
2545
2546LogicalResult ExtractSliceOp::reifyResultShapes(
2547 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2548 reifiedReturnShapes.resize(1);
2549 reifiedReturnShapes[0].reserve(getType().getRank());
2550 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2551 llvm::SmallBitVector droppedDims = getDroppedDims();
2552 for (const auto &size : enumerate(mixedSizes)) {
2553 if (droppedDims.test(size.index()))
2554 continue;
2555 reifiedReturnShapes[0].push_back(size.value());
2556 }
2557 return success();
2558}
2559
2560namespace {
2561/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2562/// This essentially pushes memref_cast past its consuming slice when
2563/// `canFoldIntoConsumerOp` is true.
2564///
2565/// Example:
2566/// ```
2567/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2568/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2569/// tensor<3x4xf32>
2570/// ```
2571/// is rewritten into:
2572/// ```
2573/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2574/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2575/// ```
2576class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2577public:
2578 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2579
2580 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2581 PatternRewriter &rewriter) const override {
2582 // Any constant operand, just return to let the constant folder kick in.
2583 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2584 return matchPattern(operand, matchConstantIndex());
2585 }))
2586 return failure();
2587
2588 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2589 if (!castOp)
2590 return failure();
2591
2592 if (!canFoldIntoConsumerOp(castOp))
2593 return failure();
2594
2595 // Pattern does not apply if the produced op would not verify.
2596 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2597 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2598 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2599 sliceOp.getStaticStrides());
2600 if (!sliceResult.isValid)
2601 return failure();
2602
2603 // Create folded extract.
2604 Location loc = sliceOp.getLoc();
2605 Value newResult = ExtractSliceOp::create(
2606 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2607 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2608 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2609 sliceOp.getStaticStrides());
2610 rewriter.replaceOp(sliceOp, newResult);
2611 return success();
2612 }
2613};
2614
2615/// Slice elements from `values` into `outValues`. `counts` represents the
2616/// numbers of elements to stride in the original values for each dimension.
2617/// The output values can be used to construct a DenseElementsAttr.
2618template <typename IterTy, typename ElemTy>
2619static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2620 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2621 ArrayRef<int64_t> strides,
2622 llvm::SmallVectorImpl<ElemTy> *outValues) {
2623 assert(offsets.size() == sizes.size());
2624 assert(offsets.size() == strides.size());
2625 if (offsets.empty())
2626 return;
2627
2628 int64_t offset = offsets.front();
2629 int64_t size = sizes.front();
2630 int64_t stride = strides.front();
2631 if (offsets.size() == 1) {
2632 for (int64_t i = 0; i < size; ++i, offset += stride)
2633 outValues->push_back(*(values + offset));
2634
2635 return;
2636 }
2637
2638 for (int64_t i = 0; i < size; ++i, offset += stride) {
2639 auto begin = values + offset * counts.front();
2640 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2641 offsets.drop_front(), sizes.drop_front(),
2642 strides.drop_front(), outValues);
2643 }
2644}
2645
2646/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2647/// folded operation might introduce more constant data; Users can control
2648/// their heuristics by the control function.
///
/// Only non-splat DenseElementsAttr constants are handled. The source and
/// result shapes and all offsets, sizes and strides must be static, and the
/// source must be non-empty. The selected elements are copied into a new
/// arith.constant of the result type.
2649class ConstantOpExtractSliceFolder final
2650 : public OpRewritePattern<ExtractSliceOp> {
2651public:
2652 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2653
2654 ConstantOpExtractSliceFolder(MLIRContext *context,
2656 : OpRewritePattern<ExtractSliceOp>(context),
2657 controlFn(std::move(controlFn)) {}
2658
2659 LogicalResult matchAndRewrite(ExtractSliceOp op,
2660 PatternRewriter &rewriter) const override {
2661 DenseElementsAttr attr;
2662 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2663 return failure();
2664
2665 // A constant splat is handled by fold().
2666 if (attr.isSplat())
2667 return failure();
2668
2669 // Dynamic result shape is not supported.
2670 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2671 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2672 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2673 return failure();
2674
2675 // Customized control over the folding.
2676 if (!controlFn(op))
2677 return failure();
2678
 // An empty source is rejected. Because the shape is static, a nonzero
 // element count also means no dimension is zero, so the divisions that
 // compute `counts` below are well defined.
2679 int64_t count = sourceType.getNumElements();
2680 if (count == 0)
2681 return failure();
2682
2683 // Check if there are any dynamic parts, which are not supported.
2684 auto offsets = op.getStaticOffsets();
2685 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2686 return failure();
2687 auto sizes = op.getStaticSizes();
2688 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2689 return failure();
2690 auto strides = op.getStaticStrides();
2691 if (llvm::is_contained(strides, ShapedType::kDynamic))
2692 return failure();
2693
2694 // Compute the stride for each dimension.
 // After the loop, counts[d] is the number of elements in one step along
 // dimension d, i.e. the product of the sizes of all later dimensions.
2695 SmallVector<int64_t> counts;
2696 ArrayRef<int64_t> shape = sourceType.getShape();
2697 counts.reserve(shape.size());
2698 for (int64_t v : shape) {
2699 count = count / v;
2700 counts.push_back(count);
2701 }
2702
2703 // Slice the elements and construct a new attribute.
2704 SmallVector<Attribute> outValues;
2705 outValues.reserve(resultType.getNumElements());
2706 sliceElements(attr.value_begin<Attribute>(), counts, offsets, sizes,
2707 strides, &outValues);
2708 auto newAttr = DenseElementsAttr::get(resultType, outValues);
2709 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2710 return success();
2711 }
2712
2713private:
2714 /// This additionally controls whether the fold happens or not. Users can
2715 /// impose their heuristics in the function.
2717};
2718
2719} // namespace
2720
2722 RewritePatternSet &patterns,
2723 const ControlConstantExtractSliceFusionFn &controlFn) {
2724 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2725}
2726
2727/// Return the canonical type of the result of an extract_slice op.
2729 RankedTensorType operator()(ExtractSliceOp op,
2730 ArrayRef<OpFoldResult> mixedOffsets,
2731 ArrayRef<OpFoldResult> mixedSizes,
2732 ArrayRef<OpFoldResult> mixedStrides) {
2733 // Infer a tensor type without taking into account any rank reductions.
2734 RankedTensorType nonReducedType =
2735 ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);
2736
2737 // Directly return the non-rank reduced type if there are no dropped
2738 // dims.
2739 llvm::SmallBitVector droppedDims = op.getDroppedDims();
2740 if (droppedDims.none())
2741 return nonReducedType;
2742
2743 // Build the reduced shape, preserving the original rank reduction pattern.
2744 SmallVector<int64_t> targetShape;
2745 for (auto i : llvm::seq<int64_t>(mixedSizes.size()))
2746 if (!droppedDims.test(i))
2747 targetShape.push_back(nonReducedType.getDimSize(i));
2748
2749 return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
2750 nonReducedType.getEncoding());
2751 }
2752};
2753
2754/// A canonicalizer wrapper to replace ExtractSliceOps.
2756 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2757 ExtractSliceOp newOp) {
2758 Value replacement = newOp.getResult();
2759 if (replacement.getType() != op.getType())
2760 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2761 replacement);
2762 rewriter.replaceOp(op, replacement);
2763 }
2764};
2765
2766void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2767 MLIRContext *context) {
2768 results.add<
2769 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2770 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2771 ExtractSliceOpCastFolder>(context);
2772}
2773
2774//
2775static LogicalResult
2776foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2777 ShapedType shapedType) {
2778 OpBuilder b(op.getContext());
2779 for (OpFoldResult ofr : op.getMixedOffsets())
2780 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2781 return failure();
2782 // Rank-reducing noops only need to inspect the leading dimensions:
2783 // llvm::zip is appropriate.
2784 auto shape = shapedType.getShape();
2785 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2786 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2787 return failure();
2788 for (OpFoldResult ofr : op.getMixedStrides())
2789 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2790 return failure();
2791 return success();
2792}
2793
2794/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2795/// slice, we can return the InsertSliceOp's source directly.
2796// TODO: This only checks the immediate producer; extend to go up the
2797// insert/extract chain if the slices are disjoint.
2798static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2799 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2800
2801 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2802 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2803 insertOp.isSameAs(extractOp, isSame))
2804 return insertOp.getSource();
2805
2806 return {};
2807}
2808
/// Folds extract_slice in three ways, tried in order:
///  - A splat constant source is reshaped to the result type via
///    reshapeConstantSource.
///  - When source and result types are equal and the remaining condition of
///    the second `if` holds, the source is returned. NOTE(review): that
///    condition is presumably the identity offset/size/stride check; confirm.
///  - An extract of exactly the slice written by a producing insert_slice
///    returns that insert's source (foldExtractAfterInsertSlice).
/// Returns a null OpFoldResult when none applies.
2809OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2810 if (OpFoldResult reshapedSource = reshapeConstantSource(
2811 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2812 getResult().getType()))
2813 return reshapedSource;
2814 if (getSourceType() == getType() &&
2816 return this->getSource();
2817 if (Value slice = foldExtractAfterInsertSlice(*this))
2818 return slice;
2819
2820 return OpFoldResult();
2821}
2822
2824 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2825 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2826 unsigned rank = rankedTensorType.getRank();
2827 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2829 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2830 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2831 offsets, sizes, strides);
2832}
2833
2834//===----------------------------------------------------------------------===//
2835// InsertSliceOp
2836//===----------------------------------------------------------------------===//
2837
2838void InsertSliceOp::getAsmResultNames(
2839 function_ref<void(Value, StringRef)> setNameFn) {
2840 setNameFn(getResult(), "inserted_slice");
2841}
2842
2843// Build an InsertSliceOp with mixed static and dynamic entries.
// dispatchIndexOpFoldResults splits each offset/size/stride into a static
// entry, stored in the static_* DenseI64ArrayAttr attributes, and, for
// dynamic entries, an SSA operand. The result type is the type of `dest`.
2844void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2845 Value dest, ArrayRef<OpFoldResult> offsets,
2847 ArrayRef<OpFoldResult> strides,
2849 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2850 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2851 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2852 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2853 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
 // Attach caller-provided attributes before delegating to the generated
 // builder.
2854 result.addAttributes(attrs);
2855 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2856 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2857 b.getDenseI64ArrayAttr(staticSizes),
2858 b.getDenseI64ArrayAttr(staticStrides));
2859}
2860
2861/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2862/// Range vector.
2863void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2864 Value dest, ArrayRef<Range> ranges,
2865 ArrayRef<NamedAttribute> attrs) {
2866 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2867 build(b, result, source, dest, offsets, sizes, strides, attrs);
2868}
2869
2870// Build a InsertSliceOp with dynamic entries.
2871void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2872 Value dest, ValueRange offsets, ValueRange sizes,
2873 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2874 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2875 offsets, [](Value v) -> OpFoldResult { return v; });
2876 SmallVector<OpFoldResult> sizeValues =
2877 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2878 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2879 strides, [](Value v) -> OpFoldResult { return v; });
2880 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2881}
2882
2883/// Rank-reducing type verification for both InsertSliceOp and
2884/// ParallelInsertSliceOp.
2886 RankedTensorType srcType, RankedTensorType dstType,
2887 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2888 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2889 // insert_slice is the inverse of extract_slice, use the same type
2890 // inference.
2891 RankedTensorType expected =
2892 ExtractSliceOp::inferResultType(dstType, staticSizes);
2893 if (expectedType)
2894 *expectedType = expected;
2895 return isRankReducedType(expected, srcType);
2896}
2897
2898/// Verifier for InsertSliceOp.
/// Checks that the source type matches the (possibly rank-reduced) type that
/// verifyInsertSliceOp infers from the destination type and static sizes,
/// reporting mismatches through produceSliceErrorMsg. Then checks that the
/// static offsets, sizes and strides stay within the destination shape.
2899LogicalResult InsertSliceOp::verify() {
2900 // Verify result type against inferred type.
2901 RankedTensorType expectedType;
2903 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2904 getStaticSizes(), getStaticStrides(), &expectedType);
2906 return produceSliceErrorMsg(result, *this, expectedType);
2907
2908 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2909 // to the destination tensor.
2910 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2911 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2912 getStaticStrides(), /*generateErrorMessage=*/true);
2913 if (!boundsResult.isValid)
2914 return getOperation()->emitError(boundsResult.errorMessage);
2915
2916 return success();
2917}
2918
2919/// If we have two consecutive InsertSliceOp writing to the same slice, we
2920/// can mutate the second InsertSliceOp's destination to the first one's.
2921///
2922/// Example:
2923///
2924/// ```mlir
2925/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2926/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2927/// ```
2928///
2929/// folds into:
2930///
2931/// ```mlir
2932/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2933/// ```
2934///
2935/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2936static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2937 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2938
2939 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2940 if (!prevInsertOp ||
2941 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2942 !prevInsertOp.isSameAs(insertOp, isSame))
2943 return failure();
2944
2945 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2946 return success();
2947}
2948
2949/// Folds round-trip extract/insert slice op pairs.
2950/// Example:
2951/// ```mlir
2952/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2953/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2954/// ```
2955/// can be folded into %val.
2956static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2957 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2958
2959 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2960 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2961 !extractOp.isSameAs(insertOp, isSame))
2962 return nullptr;
2963
2964 return extractOp.getSource();
2965}
2966
/// Folds insert_slice in four ways, tried in order:
///  - When source and result types are equal and fully static and the
///    remaining condition of the first `if` holds, the source is returned.
///    NOTE(review): that condition is presumably the identity
///    offset/size/stride check; confirm.
///  - foldInsertAfterInsertSlice may redirect the destination in place; the
///    op's own result is then returned to signal the in-place fold.
///  - An extract/insert round trip of the same slice returns the original
///    tensor.
///  - If any slice size is zero (per isZeroInteger), the destination is
///    returned.
2967OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2968 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2969 getSourceType() == getType() &&
2971 return this->getSource();
 // In-place fold: the destination operand was updated.
2972 if (succeeded(foldInsertAfterInsertSlice(*this)))
2973 return getResult();
2974 if (auto result = foldInsertAfterExtractSlice(*this))
2975 return result;
 // A zero-sized slice writes no elements.
2976 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2977 return getDest();
2978 return OpFoldResult();
2979}
2980
2981LogicalResult InsertSliceOp::reifyResultShapes(
2982 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2983 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2984 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2985 return success();
2986}
2987
2988namespace {
2989/// Pattern to rewrite a insert_slice op with constant arguments.
2990///
2991/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2992template <typename InsertOpTy>
2993class InsertSliceOpConstantArgumentFolder final
2994 : public OpRewritePattern<InsertOpTy> {
2995public:
2996 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2997
2998 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2999 PatternRewriter &rewriter) const override {
3000 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
3001 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3002 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
3003
3004 // No constant operands were folded, just return;
3005 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
3006 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
3007 failed(foldDynamicStrideList(mixedStrides)))
3008 return failure();
3009
3010 // Pattern does not apply if the produced op would not verify.
3011 SliceBoundsVerificationResult sliceResult =
3012 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
3013 mixedOffsets, mixedSizes, mixedStrides);
3014 if (!sliceResult.isValid)
3015 return failure();
3016
3017 // Create the new op in canonical form.
3018 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
3019 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
3020 mixedSizes);
3021 Value toInsert = insertSliceOp.getSource();
3022 if (sourceType != insertSliceOp.getSourceType()) {
3023 OpBuilder::InsertionGuard g(rewriter);
3024 // The only difference between InsertSliceOp and ParallelInsertSliceOp
3025 // is that the insertion point is just before the InParallelOp in
3026 // the parallel case.
3027 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
3028 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3029 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3030 sourceType, toInsert);
3031 }
3032 rewriter.replaceOpWithNewOp<InsertOpTy>(
3033 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
3034 mixedSizes, mixedStrides);
3035 return success();
3036 }
3037};
3038
3039/// Fold tensor_casts with insert_slice operations. If the source or
3040/// destination tensor is a tensor_cast that removes static type information,
3041/// the cast is folded into the insert_slice operation. E.g.:
3042///
3043/// ```mlir
3044/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3045/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3046/// ```
3047///
3048/// folds into:
3049///
3050/// ```mlir
3051/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3052/// ```
3053///
3054/// Note: When folding a cast on the destination tensor, the result of the
3055/// insert_slice operation is casted to ensure that the type of the result did
3056/// not change.
3057///
3058/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3059template <typename InsertOpTy>
3060struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3061 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3062
3063 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3064 PatternRewriter &rewriter) const override {
3065 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3066 return matchPattern(operand, matchConstantIndex());
3067 }))
3068 return failure();
3069
3070 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3071 auto castOp = v.getDefiningOp<tensor::CastOp>();
3072 if (!castOp || !canFoldIntoConsumerOp(castOp))
3073 return std::nullopt;
3074 return castOp.getSource();
3075 };
3076 std::optional<Value> sourceCastSource =
3077 getSourceOfCastOp(insertSliceOp.getSource());
3078 std::optional<Value> destCastSource =
3079 getSourceOfCastOp(insertSliceOp.getDest());
3080 if (!sourceCastSource && !destCastSource)
3081 return failure();
3082
3083 auto src =
3084 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3085 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3086 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3087 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3088 if (!srcType || !dstType)
3089 return failure();
3090
3091 // The tensor.cast source could have additional static information not seen
3092 // in the insert slice op static sizes, so we ignore dynamic dims when
3093 // computing the rank reduction mask.
3094 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3095 auto rankReductionMask = computeRankReductionMask(
3096 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3097 if (!rankReductionMask.has_value())
3098 return failure();
3099 // Replace dimensions in the insert slice op with corresponding static dims
3100 // from the cast source type. If the insert slice sizes have static dims
3101 // that are not static in the tensor.cast source (i.e., when the cast op
3102 // casts a dynamic dim to static), the dim should not be replaced, and the
3103 // pattern will fail later in `verifyInsertSliceOp`.
3104 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3105 int64_t rankReducedIdx = 0;
3106 for (auto [idx, size] : enumerate(staticSizes)) {
3107 if (!rankReductionMask.value().contains(idx) &&
3108 !srcType.isDynamicDim(rankReducedIdx)) {
3109 mixedSizes[idx] = getAsIndexOpFoldResult(
3110 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3111 size = srcType.getDimSize(rankReducedIdx++);
3112 }
3113 }
3114
3115 // Pattern does not apply if the produced op would not verify.
3116 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3117 staticSizes, insertSliceOp.getStaticStrides()) !=
3118 SliceVerificationResult::Success)
3119 return failure();
3120 SliceBoundsVerificationResult sliceResult =
3121 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3122 mixedSizes, insertSliceOp.getMixedStrides());
3123 if (!sliceResult.isValid)
3124 return failure();
3125
3126 Operation *replacement =
3127 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3128 insertSliceOp.getMixedOffsets(), mixedSizes,
3129 insertSliceOp.getMixedStrides());
3130
3131 // In the parallel case there is no result and so nothing to cast.
3132 bool isParallelInsert =
3133 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3134 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3135 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3136 insertSliceOp.getDestType(),
3137 replacement->getResult(0));
3138 }
3139 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3140 return success();
3141 }
3142};
3143
3144/// If additional static type information can be deduced from a insert_slice's
3145/// size operands, insert an explicit cast of the op's source operand. This
3146/// enables other canonicalization patterns that are matching for tensor_cast
3147/// ops such as `ForOpTensorCastFolder` in SCF.
3148///
3149/// Example:
3150///
3151/// ```mlir
3152/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3153/// : tensor<?x?xf32> into ...
3154/// ```
3155///
3156/// folds into:
3157///
3158/// ```mlir
3159/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3160/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3161/// : tensor<64x64xf32> into ...
3162/// ```
3163///
3164/// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
3165template <typename InsertOpTy>
3166struct InsertSliceOpSourceCastInserter final
3167 : public OpRewritePattern<InsertOpTy> {
3168 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3169
3170 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3171 PatternRewriter &rewriter) const override {
3172 RankedTensorType srcType = insertSliceOp.getSourceType();
3173 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3174 return failure();
3175 SmallVector<int64_t> newSrcShape(srcType.getShape());
3176 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3177 if (std::optional<int64_t> constInt =
3178 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3179 // Bail on invalid IR.
3180 if (*constInt < 0)
3181 return failure();
3182 newSrcShape[i] = *constInt;
3183 }
3184 }
3185 if (!hasValidSizesOffsets(newSrcShape))
3186 return failure();
3187
3188 RankedTensorType newSrcType = RankedTensorType::get(
3189 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3190 if (srcType == newSrcType ||
3191 !preservesStaticInformation(srcType, newSrcType) ||
3192 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3193 return failure();
3194
3195 // newSrcType is:
3196 // 1) Different from srcType.
3197 // 2) "More static" than srcType.
3198 // 3) Cast-compatible with srcType.
3199 // Insert the cast.
3200 OpBuilder::InsertionGuard g(rewriter);
3201 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3202 // that the insertion point is just before the InParallelOp in the
3203 // parallel case.
3204 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3205 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3206 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3207 newSrcType, insertSliceOp.getSource());
3208 rewriter.replaceOpWithNewOp<InsertOpTy>(
3209 insertSliceOp, cast, insertSliceOp.getDest(),
3210 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3211 insertSliceOp.getMixedStrides());
3212 return success();
3213 }
3214};
3215} // namespace
3216
3217llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3218 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3219}
3220
3221void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3222 MLIRContext *context) {
3223 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3224 InsertSliceOpCastFolder<InsertSliceOp>,
3225 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3226}
3227
3229 Location loc,
3230 Value tensor,
3231 Value dest) {
3232 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3233 unsigned rank = rankedTensorType.getRank();
3234 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3235 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3236 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3237 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3238 sizes, strides);
3239}
3240
3241//===----------------------------------------------------------------------===//
3242// PadOp
3243//===----------------------------------------------------------------------===//
3244
3245void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3246 setNameFn(getResult(), "padded");
3247}
3248
3249LogicalResult PadOp::verify() {
3250 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3251 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3252 auto expectedType =
3253 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3254 if (!expectedType) {
3255 return emitError("failed to infer expectedType from sourceType ")
3256 << sourceType << ", specified resultType is " << resultType;
3257 }
3258 if (resultType.getRank() != expectedType.getRank()) {
3259 return emitError("specified type ")
3260 << resultType << " does not match the inferred type "
3261 << expectedType;
3262 }
3263 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3264 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3265 continue;
3266 if (expectedType.isDynamicDim(i))
3267 continue;
3268 return emitError("specified type ")
3269 << resultType << " does not match the inferred type "
3270 << expectedType;
3271 }
3272
3273 return success();
3274}
3275
3276LogicalResult PadOp::verifyRegions() {
3277 auto &region = getRegion();
3278 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3279 Block &block = region.front();
3280 if (block.getNumArguments() != rank)
3281 return emitError("expected the block to have ") << rank << " arguments";
3282
3283 // Note: the number and type of yield values are checked in the YieldOp.
3284 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3285 if (!en.value().isIndex())
3286 return emitOpError("expected block argument ")
3287 << (en.index() + 1) << " to be an index";
3288 }
3289
3290 // Ensure that the region yields an element of the right type.
3291 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3292 if (yieldOp.getValue().getType() !=
3293 llvm::cast<ShapedType>(getType()).getElementType())
3294 return emitOpError("expected yield type to match shape element type");
3295
3296 return success();
3297}
3298
3299RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3300 ArrayRef<int64_t> staticLow,
3301 ArrayRef<int64_t> staticHigh,
3302 ArrayRef<int64_t> resultShape) {
3303 unsigned rank = sourceType.getRank();
3304 if (staticLow.size() != rank)
3305 return RankedTensorType();
3306 if (staticHigh.size() != rank)
3307 return RankedTensorType();
3308 if (!resultShape.empty() && resultShape.size() != rank)
3309 return RankedTensorType();
3310
3311 SmallVector<int64_t, 4> inferredShape;
3312 for (auto i : llvm::seq<unsigned>(0, rank)) {
3313 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3314 staticHigh[i] == ShapedType::kDynamic) {
3315 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3316 : resultShape[i]);
3317 } else {
3318 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3319 assert((resultShape.empty() || size == resultShape[i] ||
3320 resultShape[i] == ShapedType::kDynamic) &&
3321 "mismatch between inferred shape and result shape");
3322 inferredShape.push_back(size);
3323 }
3324 }
3325
3326 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3327}
3328
3329void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3330 Value source, ArrayRef<int64_t> staticLow,
3331 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3332 bool nofold, ArrayRef<NamedAttribute> attrs) {
3333 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3334 if (!resultType)
3335 resultType = inferResultType(sourceType, staticLow, staticHigh);
3336 result.addAttributes(attrs);
3337 build(b, result, resultType, source, low, high,
3338 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3339 nofold ? b.getUnitAttr() : UnitAttr());
3340}
3341
3342void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3343 Value source, ValueRange low, ValueRange high, bool nofold,
3344 ArrayRef<NamedAttribute> attrs) {
3345 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3346 unsigned rank = sourceType.getRank();
3347 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3348 build(b, result, resultType, source, staticVector, staticVector, low, high,
3349 nofold, attrs);
3350}
3351
3352void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3353 Value source, ArrayRef<OpFoldResult> low,
3354 ArrayRef<OpFoldResult> high, bool nofold,
3355 ArrayRef<NamedAttribute> attrs) {
3356 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3357 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3358 SmallVector<int64_t, 4> staticLow, staticHigh;
3359 // staticLow and staticHigh have full information of the padding config.
3360 // This will grow staticLow and staticHigh with 1 value. If the config is
3361 // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
3362 // value as well.
3363 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3364 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3365 if (!resultType) {
3366 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3367 }
3368 assert(llvm::isa<RankedTensorType>(resultType));
3369 result.addAttributes(attrs);
3370 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3371 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3372 nofold ? b.getUnitAttr() : UnitAttr());
3373}
3374
3375void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3376 Value source, ArrayRef<OpFoldResult> low,
3377 ArrayRef<OpFoldResult> high, Value constantPadValue,
3378 bool nofold, ArrayRef<NamedAttribute> attrs) {
3379 build(b, result, resultType, source, low, high, nofold, attrs);
3380
3381 // Add a region and a block to yield the pad value.
3382 Region *region = result.regions[0].get();
3383 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3384 Repeated<Type> blockArgTypes(sourceRank, b.getIndexType());
3385 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3386
3387 // `builder.createBlock` changes the insertion point within the block. Create
3388 // a guard to reset the insertion point of the builder after it is destroyed.
3389 OpBuilder::InsertionGuard guard(b);
3390 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3391 tensor::YieldOp::create(b, result.location, constantPadValue);
3392}
3393
3394llvm::SmallBitVector PadOp::getPaddedDims() {
3395 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3396 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3397 for (const auto &en : enumerate(paddingWidths))
3398 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3399 paddedDims.set(en.index());
3400 };
3401 extractPaddedDims(getMixedLowPad());
3402 extractPaddedDims(getMixedHighPad());
3403 return paddedDims;
3404}
3405
3406namespace {
3407// Folds tensor.pad when padding is static zeros and the attribute
3408// doesn't request otherwise.
3409struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3410 using OpRewritePattern<PadOp>::OpRewritePattern;
3411
3412 LogicalResult matchAndRewrite(PadOp padTensorOp,
3413 PatternRewriter &rewriter) const override {
3414 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3415 return failure();
3416 if (padTensorOp.getNofold())
3417 return failure();
3418 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3419 padTensorOp, padTensorOp.getResult().getType(),
3420 padTensorOp.getSource());
3421 return success();
3422 }
3423};
3424
3425// Fold CastOp into PadOp when adding static information.
3426struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3427 using OpRewritePattern<PadOp>::OpRewritePattern;
3428
3429 LogicalResult matchAndRewrite(PadOp padTensorOp,
3430 PatternRewriter &rewriter) const override {
3431 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3432 if (!tensor::canFoldIntoConsumerOp(castOp))
3433 return failure();
3434
3435 auto newResultType = PadOp::inferResultType(
3436 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3437 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3438 padTensorOp.getResultType().getShape());
3439
3440 if (newResultType == padTensorOp.getResultType()) {
3441 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3442 padTensorOp.getSourceMutable().assign(castOp.getSource());
3443 });
3444 } else {
3445 auto newOp = PadOp::create(
3446 rewriter, padTensorOp->getLoc(), newResultType,
3447 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3448 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3449 padTensorOp.getHigh(), padTensorOp.getNofold(),
3450 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3451 IRMapping mapper;
3452 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3453
3454 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3455 padTensorOp, padTensorOp.getResultType(), newOp);
3456 }
3457 return success();
3458 }
3459};
3460
3461// Fold CastOp using the result of PadOp back into the latter if it adds
3462// static information.
3463struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3464 using OpRewritePattern<PadOp>::OpRewritePattern;
3465
3466 LogicalResult matchAndRewrite(PadOp padTensorOp,
3467 PatternRewriter &rewriter) const override {
3468 if (!padTensorOp.getResult().hasOneUse())
3469 return failure();
3470 auto tensorCastOp =
3471 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3472 if (!tensorCastOp)
3473 return failure();
3474 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3475 tensorCastOp.getDest().getType()))
3476 return failure();
3477
3478 auto replacementOp = PadOp::create(
3479 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3480 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3481 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3482 padTensorOp.getHigh(), padTensorOp.getNofold(),
3483 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3484 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3485
3486 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3487 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3488 return success();
3489 }
3490};
3491
3492/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3493/// different dimensions. The pattern applies if the following preconditions
3494/// hold:
3495/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3496/// 2) the tensor::ExtractSliceOps have only unit-strides,
3497/// 3) the tensor::PadOps perform only high-padding,
3498/// 4) the tensor::PadOps have the same constant padding value,
3499/// 5) the tensor::PadOps do not have common padding dimensions,
3500/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3501/// zero-offset for every dimension.
3502/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3503/// the
3504/// padded source dimensions.
3505///
3506/// Example:
3507///
3508/// ```mlir
3509/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3510/// : tensor<64x64xf32> to tensor<?x64xf32>
3511/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3512/// } : tensor<?x64xf32> to tensor<8x64xf32>
3513/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3514/// : tensor<8x64xf32> to tensor<8x?xf32>
3515/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3516/// } : tensor<8x?xf32> to tensor<8x4xf32>
3517/// ```
3518///
3519/// folds into:
3520///
3521/// ```mlir
3522/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3523/// : tensor<64x64xf32> to tensor<?x?xf32>
3524/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3525/// } : tensor<?x?xf32> to tensor<8x4xf32>
3526/// ```
3527struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3528 using OpRewritePattern<PadOp>::OpRewritePattern;
3529
3530 LogicalResult matchAndRewrite(PadOp padOp,
3531 PatternRewriter &rewriter) const override {
3532 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3533 if (!innerSliceOp)
3534 return failure();
3535 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3536 if (!outerPadOp || outerPadOp.getNofold())
3537 return failure();
3538 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3539 if (!outerSliceOp)
3540 return failure();
3541
3542 // 1) Fail if the chain is rank-reducing.
3543 int64_t rank = padOp.getSourceType().getRank();
3544 if (outerSliceOp.getSourceType().getRank() != rank) {
3545 return rewriter.notifyMatchFailure(padOp,
3546 "cannot fold rank-reducing chain");
3547 }
3548
3549 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3550 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3551 return rewriter.notifyMatchFailure(
3552 padOp, "cannot fold non-unit stride ExtractSliceOps");
3553 }
3554
3555 // 3) Fail if the tensor::PadOps have non-zero low padding.
3556 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3557 return rewriter.notifyMatchFailure(padOp,
3558 "cannot fold PadOps with low padding");
3559 }
3560
3561 // 4) Fail if the tensor::PadOps padding values do not match.
3562 Attribute innerAttr, outerAttr;
3563 Value innerValue = padOp.getConstantPaddingValue();
3564 Value outerValue = outerPadOp.getConstantPaddingValue();
3565 if (!innerValue || !outerValue ||
3566 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3567 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3568 innerAttr != outerAttr) {
3569 return rewriter.notifyMatchFailure(
3570 padOp, "cannot fold PadOps with different padding values");
3571 }
3572
3573 // 5) Fail if a dimension is padded by both tensor::PadOps.
3574 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3575 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3576 if (innerDims.anyCommon(outerDims)) {
3577 return rewriter.notifyMatchFailure(
3578 padOp, "cannot fold PadOps with common padding dimensions");
3579 }
3580
3581 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3582 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3583 // for every dimension, and use the offset the other pair. Fail if no
3584 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3585 // exists.
3586 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3587 for (auto en : enumerate(newOffsets)) {
3588 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3589 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3590 if (!innerDims.test(en.index()) &&
3591 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3592 en.value() = outerOffset;
3593 continue;
3594 }
3595 if (!outerDims.test(en.index()) &&
3596 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3597 en.value() = innerOffset;
3598 continue;
3599 }
3600 return rewriter.notifyMatchFailure(
3601 padOp, "cannot find zero-offset and zero-padding pair");
3602 }
3603
3604 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3605 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3606 // outer tensor::PadOp and fail if the size of the inner
3607 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3608 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3609 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3610 for (auto en : enumerate(newSizes)) {
3611 if (!outerDims.test(en.index()))
3612 continue;
3613 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3614 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3615 assert(ShapedType::isStatic(sourceSize) &&
3616 "expected padded dimension to have a static size");
3617 if (getConstantIntValue(sliceSize) != sourceSize) {
3618 return rewriter.notifyMatchFailure(
3619 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3620 "match the size of the outer padding");
3621 }
3622 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3623 }
3624
3625 // Combine the high paddings of the two tensor::PadOps.
3626 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3627 for (auto en : enumerate(newHighPad)) {
3628 if (innerDims.test(en.index()))
3629 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3630 if (outerDims.test(en.index()))
3631 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3632 }
3633
3634 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3635 // the two paddings in one step.
3636 auto newSliceOp = ExtractSliceOp::create(
3637 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3638 newSizes, innerSliceOp.getMixedStrides());
3639 auto newPadOp = PadOp::create(
3640 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3641 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3642 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3643 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3644 newPadOp.getRegion().begin());
3645 rewriter.replaceOp(padOp, newPadOp.getResult());
3646 return success();
3647 }
3648};
3649
3650struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3651 using OpRewritePattern<PadOp>::OpRewritePattern;
3652
3653 LogicalResult matchAndRewrite(PadOp padTensorOp,
3654 PatternRewriter &rewriter) const override {
3655 Value input = padTensorOp.getSource();
3656 if (!llvm::isa<RankedTensorType>(input.getType()))
3657 return failure();
3658 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3659 auto inputRank = inputDims.size();
3660
3661 auto oldResultType =
3662 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3663 if (!oldResultType)
3664 return failure();
3665
3666 auto outputDims = oldResultType.getShape();
3667
3668 // Extract the static info from the high and low operands.
3669 SmallVector<int64_t> constOperandsLow;
3670 SmallVector<Value> newLows;
3671 for (auto operand : padTensorOp.getLow()) {
3672 APSInt intOp;
3673 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3674 constOperandsLow.push_back(ShapedType::kDynamic);
3675 newLows.push_back(operand);
3676 continue;
3677 }
3678 constOperandsLow.push_back(intOp.getExtValue());
3679 }
3680 SmallVector<int64_t> constOperandsHigh;
3681 SmallVector<Value> newHighs;
3682 for (auto operand : padTensorOp.getHigh()) {
3683 APSInt intOp;
3684 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3685 constOperandsHigh.push_back(ShapedType::kDynamic);
3686 newHighs.push_back(operand);
3687 continue;
3688 }
3689 constOperandsHigh.push_back(intOp.getExtValue());
3690 }
3691
3692 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3693 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3694
3695 // Verify the op is well-formed.
3696 if (inputDims.size() != outputDims.size() ||
3697 inputDims.size() != constLow.size() ||
3698 inputDims.size() != constHigh.size())
3699 return failure();
3700
3701 auto lowCount = 0;
3702 auto highCount = 0;
3703 for (size_t i = 0; i < inputRank; i++) {
3704 if (constLow[i] == ShapedType::kDynamic)
3705 constLow[i] = constOperandsLow[lowCount++];
3706 if (constHigh[i] == ShapedType::kDynamic)
3707 constHigh[i] = constOperandsHigh[highCount++];
3708 }
3709
3710 auto staticLow = ArrayRef<int64_t>(constLow);
3711 auto staticHigh = ArrayRef<int64_t>(constHigh);
3712
3713 // Calculate the output sizes with the static information.
3714 SmallVector<int64_t> newOutDims;
3715 for (size_t i = 0; i < inputRank; i++) {
3716 if (outputDims[i] == ShapedType::kDynamic) {
3717 newOutDims.push_back(
3718 (staticLow[i] == ShapedType::kDynamic ||
3719 staticHigh[i] == ShapedType::kDynamic ||
3720 inputDims[i] == ShapedType::kDynamic
3721 ? ShapedType::kDynamic
3722 : inputDims[i] + staticLow[i] + staticHigh[i]));
3723 } else {
3724 newOutDims.push_back(outputDims[i]);
3725 }
3726 }
3727
3728 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3729 llvm::all_of(newOutDims,
3730 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3731 return failure();
3732
3733 // Rewrite the op using the new static type.
3734 auto newResultType = RankedTensorType::get(
3735 newOutDims, padTensorOp.getType().getElementType());
3736 auto newOp = PadOp::create(
3737 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3738 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3739 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3740
3741 IRMapping mapper;
3742 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3743 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3744 newOp);
3745
3746 return success();
3747 }
3748};
3749
3750/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3751///
3752/// Example:
3753///
3754/// ```mlir
3755/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3756/// tensor.yield %val
3757/// } : tensor<1x2xf32> to tensor<2x5xf32>
3758/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3759/// tensor.yield %val
3760/// } : tensor<1x5xf32> to tensor<5x7xf32>
3761/// ```
3762///
3763/// folds into:
3764///
3765/// ```mlir
3766/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3767/// tensor.yield %val
3768/// } : tensor<1x2xf32> to tensor<5x7xf32>
3769/// ```
3770struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3771 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3772
3773 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3774 PatternRewriter &rewriter) const override {
3775 if (padOp.getNofold()) {
3776 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3777 }
3778
3779 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3780 if (!producerPad || producerPad.getNofold()) {
3781 return rewriter.notifyMatchFailure(
3782 padOp, "producer is not a foldable tensor.pad op");
3783 }
3784
3785 // Fail if the tensor::PadOps padding values do not match.
3786 Value consumerPadValue = padOp.getConstantPaddingValue();
3787 Value producerPadValue = producerPad.getConstantPaddingValue();
3788 if (!consumerPadValue || !producerPadValue ||
3789 consumerPadValue != producerPadValue) {
3790 return rewriter.notifyMatchFailure(
3791 padOp,
3792 "cannot fold PadOps with different or non-constant padding values");
3793 }
3794
3795 Location loc = padOp.getLoc();
3796 AffineExpr d0, d1;
3797 bindDims(rewriter.getContext(), d0, d1);
3798
3799 // Combine the low/high paddings of the two tensor::PadOps.
3800 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3801 ArrayRef<OpFoldResult> producerPaddings) {
3802 SmallVector<OpFoldResult> sumPaddings;
3803 for (auto [consumerIndex, producerIndex] :
3804 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3805 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3806 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3807 }
3808 return sumPaddings;
3809 };
3810
3811 SmallVector<OpFoldResult> newHighPad =
3812 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3813 SmallVector<OpFoldResult> newLowPad =
3814 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3815
3816 auto newPadOp = tensor::PadOp::create(
3817 rewriter, padOp.getLoc(), padOp.getResultType(),
3818 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3819 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3820 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3821 newPadOp.getRegion().begin());
3822 rewriter.replaceOp(padOp, newPadOp.getResult());
3823 return success();
3824 }
3825};
3826
3827} // namespace
3828
3829LogicalResult
3830PadOp::reifyResultShapes(OpBuilder &b,
3831 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3832 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3833 SmallVector<OpFoldResult> lp = getMixedLowPad();
3834 SmallVector<OpFoldResult> hp = getMixedHighPad();
3835 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3836 if (!getType().isDynamicDim(i)) {
3837 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3838 continue;
3839 }
3840 Location loc = getLoc();
3841 Value dim = b.createOrFold<tensor::DimOp>(
3842 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3843
3844 AffineExpr d0, d1, d2;
3845 bindDims(b.getContext(), d0, d1, d2);
3846 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3847 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3848 }
3849 return success();
3850}
3851
3852void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3853 MLIRContext *context) {
3854 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3855 FoldOrthogonalPaddings, FoldStaticPadding,
3856 FoldConsecutiveConstantPadding>(context);
3857}
3858
3859/// Return the padding value of the PadOp if it constant. In this context,
3860/// "constant" means an actual constant or "defined outside of the block".
3861///
3862/// Values are considered constant in three cases:
3863/// - A ConstantLike value.
3864/// - A basic block argument from a different block.
3865/// - A value defined outside of the block.
3866///
3867/// If the padding value is not constant, an empty Value is returned.
3868Value PadOp::getConstantPaddingValue() {
3869 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3870 if (!yieldOp)
3871 return {};
3872 Value padValue = yieldOp.getValue();
3873 // Check if yield value is a constant.
3874 if (matchPattern(padValue, m_Constant()))
3875 return padValue;
3876 // Check if yield value is defined inside the PadOp block.
3877 if (padValue.getParentBlock() == &getRegion().front())
3878 return {};
3879 // Else: Yield value defined outside of the PadOp block.
3880 return padValue;
3881}
3882
3883OpFoldResult PadOp::fold(FoldAdaptor) {
3884 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3885 !getNofold())
3886 return getSource();
3887 return {};
3888}
3889
3890//===----------------------------------------------------------------------===//
3891// ParallelInsertSliceOp
3892//===----------------------------------------------------------------------===//
3893
3894OpResult ParallelInsertSliceOp::getTiedOpResult() {
3895 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3896 for (const auto &it :
3897 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3898 Operation &nextOp = it.value();
3899 if (&nextOp == getOperation())
3900 return parallelCombiningParent.getParentResult(it.index());
3901 }
3902 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3903}
3904
3905// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3906void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3907 Value source, Value dest,
3908 ArrayRef<OpFoldResult> offsets,
3909 ArrayRef<OpFoldResult> sizes,
3910 ArrayRef<OpFoldResult> strides,
3911 ArrayRef<NamedAttribute> attrs) {
3912 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3913 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3914 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3915 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3916 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3917 result.addAttributes(attrs);
3918 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3919 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3920 b.getDenseI64ArrayAttr(staticSizes),
3921 b.getDenseI64ArrayAttr(staticStrides));
3922}
3923
3924/// Build an ParallelInsertSliceOp with mixed static and dynamic entries
3925/// packed into a Range vector.
3926void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3927 Value source, Value dest,
3928 ArrayRef<Range> ranges,
3929 ArrayRef<NamedAttribute> attrs) {
3930 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3931 build(b, result, source, dest, offsets, sizes, strides, attrs);
3932}
3933
3934// Build a ParallelInsertSliceOp with dynamic entries.
3935void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3936 Value source, Value dest, ValueRange offsets,
3937 ValueRange sizes, ValueRange strides,
3938 ArrayRef<NamedAttribute> attrs) {
3939 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3940 offsets, [](Value v) -> OpFoldResult { return v; });
3941 SmallVector<OpFoldResult> sizeValues =
3942 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3943 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3944 strides, [](Value v) -> OpFoldResult { return v; });
3945 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3946}
3947
3948// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3949// to 0, strides set to 1 and inferred result type.
3950void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3951 Value dest, ArrayRef<OpFoldResult> sizes,
3952 ArrayRef<NamedAttribute> attrs) {
3953 Attribute zeroIdxAttr = b.getIndexAttr(0);
3954 Attribute oneIdxAttr = b.getIndexAttr(1);
3955 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3956 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3957 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3958}
3959
3960LogicalResult ParallelInsertSliceOp::verify() {
3961 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3962 return this->emitError("expected InParallelOpInterface parent, got:")
3963 << *(getOperation()->getParentOp());
3964
3965 // Verify result type against inferred type.
3966 RankedTensorType expectedType;
3968 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3969 getStaticSizes(), getStaticStrides(), &expectedType);
3971 return produceSliceErrorMsg(result, *this, expectedType);
3972
3973 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3974 // to the destination tensor.
3975 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3976 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3977 getStaticStrides(), /*generateErrorMessage=*/true);
3978 if (!boundsResult.isValid)
3979 return getOperation()->emitError(boundsResult.errorMessage);
3980
3981 return success();
3982}
3983
3984void ParallelInsertSliceOp::getCanonicalizationPatterns(
3985 RewritePatternSet &results, MLIRContext *context) {
3986 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3987 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3988 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3989}
3990
3991llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3992 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3993}
3994
3995// ParallelCombiningOpInterface implementation.
3996MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3997 return getDestMutable();
3998}
3999
4000Operation *ParallelInsertSliceOp::getIteratingParent() {
4001 // Return the parent InParallelOpInterface's parent.
4002 if (auto combiningOp =
4003 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
4004 return combiningOp->getParentOp();
4005 return nullptr;
4006}
4007
4008//===----------------------------------------------------------------------===//
4009// ScatterOp
4010//===----------------------------------------------------------------------===//
4011
4012void ScatterOp::getAsmResultNames(
4013 function_ref<void(Value, StringRef)> setNameFn) {
4014 setNameFn(getResult(), "scatter");
4015}
4016
4017LogicalResult ScatterOp::verify() {
4018 int64_t destRank = getDestType().getRank();
4019 ArrayRef<int64_t> scatterDims = getScatterDims();
4020 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
4021 getIndicesType().getShape(), destRank,
4022 "scatter", "dest")))
4023 return failure();
4024
4025 if (!getUnique())
4026 return emitOpError("requires 'unique' attribute to be set");
4027 // TODO: we could also check statically that there are fewer leading index
4028 // tensor dims than the dest dims. If this is not the case, the unique
4029 // attribute cannot be true.
4030
4031 // Use the GatherOp::inferResultType on the `dest` type and verify the
4032 // expected type matches the source type.
4033 RankedTensorType expectedSourceType = GatherOp::inferResultType(
4034 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
4035 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
4036 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
4037 if (getSourceType() != expectedSourceType &&
4038 getSourceType() != expectedRankReducedSourceType) {
4039 return emitOpError("source type "
4040 "mismatch: "
4041 "expected ")
4042 << expectedSourceType << " or its rank-reduced variant "
4043 << expectedRankReducedSourceType << " (got: " << getSourceType()
4044 << ")";
4045 }
4046
4047 return success();
4048}
4049
4050//===----------------------------------------------------------------------===//
4051// SplatOp
4052//===----------------------------------------------------------------------===//
4053
4054void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4055 Type aggregateType, ValueRange dynamicSizes) {
4056 build(builder, result, aggregateType, element, dynamicSizes);
4057}
4058
4059void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4060 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4061 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4062 build(builder, result, aggregateType, element, dynamicSizes);
4063}
4064
4065void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4066 ArrayRef<OpFoldResult> sizes) {
4067 SmallVector<int64_t> staticShape;
4068 SmallVector<Value> dynamicSizes;
4069 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4070 build(builder, result, element, staticShape, dynamicSizes);
4071}
4072
4073void SplatOp::getAsmResultNames(
4074 function_ref<void(Value, StringRef)> setNameFn) {
4075 setNameFn(getResult(), "splat");
4076}
4077
4078LogicalResult SplatOp::verify() {
4079 return verifyDynamicDimensionCount(getOperation(), getType(),
4080 getDynamicSizes());
4081}
4082
4083LogicalResult
4084SplatOp::reifyResultShapes(OpBuilder &builder,
4085 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4086 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4087 unsigned ctr = 0;
4088 for (int64_t i = 0; i < getType().getRank(); ++i) {
4089 if (getType().isDynamicDim(i)) {
4090 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4091 } else {
4092 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4093 }
4094 }
4095 return success();
4096}
4097
4098OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4099 auto constOperand = adaptor.getInput();
4100 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4101 return {};
4102
4103 // Do not fold if the splat is not statically shaped
4104 if (!getType().hasStaticShape())
4105 return {};
4106
4107 // SplatElementsAttr::get treats single value for second arg as being a
4108 // splat.
4109 return SplatElementsAttr::get(getType(), {constOperand});
4110}
4111
4112//===----------------------------------------------------------------------===//
4113// Common Canonicalizers and Folders.
4114//===----------------------------------------------------------------------===//
4115static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4116 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4117 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4118 // might need special handling of attached regions.
4119 if (isa<InsertSliceOp>(op.getOperation()) ||
4120 isa<LoopLikeOpInterface>(op.getOperation()))
4121 return false;
4122
4124}
4125
4126/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4127/// the `tensor.cast` has source that is more static than the consuming op.
4128///
4129/// Example:
4130/// ```mlir
4131/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4132/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4133/// ```
4134///
4135/// folds into:
4136///
4137/// ```mlir
4138/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4139/// ```
4140/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4141/// can add the pattern to their canonicalizers.
4143 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4145 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4146
4147 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4148 PatternRewriter &rewriter) const override {
4149
4150 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4151 // for that instead.
4152 if (!foldTensorCastPrecondition(op) ||
4153 isa<linalg::RelayoutOpInterface>(*op))
4154 return failure();
4155
4156 SmallVector<Type> newResultTypes(op->getResultTypes());
4157 SmallVector<Value> newOperands =
4158 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4159
4160 // Clone op
4161 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4162
4163 SmallVector<Value, 4> replacements;
4164 replacements.reserve(newOp->getNumResults());
4165 for (auto [oldResult, newResult] :
4166 llvm::zip(op->getResults(), newOp->getResults())) {
4167 if (newResult.getType() != oldResult.getType()) {
4168 replacements.push_back(tensor::CastOp::create(
4169 rewriter, op->getLoc(), oldResult.getType(), newResult));
4170 } else {
4171 replacements.push_back(newResult);
4172 }
4173 }
4174 rewriter.replaceOp(op, replacements);
4175
4176 return success();
4177 }
4178};
4179
4180//===----------------------------------------------------------------------===//
4181// TensorDialect
4182//===----------------------------------------------------------------------===//
4183
4184void TensorDialect::getCanonicalizationPatterns(
4185 RewritePatternSet &results) const {
4186 results.add<FoldTensorCastProducerOp>(getContext());
4187}
4188
4189//===----------------------------------------------------------------------===//
4190// TableGen'd op method definitions
4191//===----------------------------------------------------------------------===//
4192
4193#define GET_OP_CLASSES
4194#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:373
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:369
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:383
MLIRContext * getContext() const
Definition Builders.h:56
auto value_begin() const
Get an iterator of the given type to the start of the held element values.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:408
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:440
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:78
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:93
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.