TensorOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Builders.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/Matchers.h"
33#include "mlir/Support/LLVM.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SmallBitVector.h"
37#include "llvm/ADT/SmallVectorExtras.h"
38#include "llvm/ADT/StringRef.h"
39#include "llvm/Support/Casting.h"
40#include "llvm/Support/MathExtras.h"
41#include <optional>
42
43using namespace mlir;
44using namespace mlir::tensor;
45
46/// Materialize a single constant operation from a given attribute value with
47/// the desired resultant type.
48Operation *TensorDialect::materializeConstant(OpBuilder &builder,
49 Attribute value, Type type,
50 Location loc) {
51 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
52 return op;
53 if (complex::ConstantOp::isBuildableWith(value, type))
54 return complex::ConstantOp::create(builder, loc, type,
55 llvm::cast<ArrayAttr>(value));
56 return nullptr;
57}
58
59OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
60 int64_t dim) {
61 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
62 if (tensorType.isDynamicDim(dim))
63 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
64
65 return builder.getIndexAttr(tensorType.getDimSize(dim));
66}
67
68SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
69 Location loc, Value value) {
70 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
71 SmallVector<OpFoldResult> result;
72 for (int64_t i = 0; i < tensorType.getRank(); ++i)
73 result.push_back(getMixedSize(builder, loc, value, i));
74 return result;
75}
76
77FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
78 OpResult opResult) {
79 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
80 assert(tensorType && "expected tensor type");
81
82 // If the op has a destination, it implements DestinationStyleOpInterface and
83 // we can query the destination operand from that interface.
84 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
85 if (destOp)
86 return destOp.getTiedOpOperand(opResult)->get();
87
88 // Otherwise, create a new destination tensor with the same shape.
89 OpBuilder::InsertionGuard g(b);
90 b.setInsertionPoint(opResult.getDefiningOp());
91
92 // Compute sizes.
93 SmallVector<OpFoldResult> mixedSizes;
94 if (!tensorType.hasStaticShape()) {
95 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
96 ReifiedRankedShapedTypeDims reifiedShapes;
97 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
98 return failure();
99 mixedSizes = reifiedShapes[opResult.getResultNumber()];
100 } else {
101 // Static shape: Take static sizes directly.
102 for (int64_t sz : tensorType.getShape())
103 mixedSizes.push_back(b.getIndexAttr(sz));
104 }
105
106 // Create empty tensor.
107 Value emptyTensor =
108 tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
109 return emptyTensor;
110}
111
112LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
113 Operation *op,
114 SmallVector<Value> &result) {
115 for (OpResult opResult : op->getResults()) {
116 if (llvm::isa<TensorType>(opResult.getType())) {
117 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
118 if (failed(destination))
119 return failure();
120 result.push_back(*destination);
121 }
122 }
123 return success();
124}
125
126bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
127 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
128 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
129 return rtp1.getShape() == rtp2.getShape() &&
130 rtp1.getElementType() == rtp2.getElementType();
131 return false;
132 }
133 return tp1 == tp2; // default implementation
134}
135
136/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
137/// rank-extending tensor.insert_slice op.
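/// Illustrative example (editorial sketch, not from the original source): for a
/// rank-reducing extract_slice with mixed sizes [1, %sz, 1] whose reduced result
/// shape is [?], the static unit sizes at positions 0 and 2 are the dropped
/// dimensions, so the returned bit vector is {1, 0, 1}.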
138static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
139 ArrayRef<OpFoldResult> mixedSizes) {
140 llvm::SmallBitVector droppedDims(mixedSizes.size());
141 int64_t shapePos = reducedShape.size() - 1;
142
143 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
144 size_t idx = mixedSizes.size() - size.index() - 1;
145 // Rank-reduced dims must have a static unit dimension.
146 bool isStaticUnitSize =
147 isa<Attribute>(size.value()) &&
148 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
149
150 if (shapePos < 0) {
151 // There are no more dims in the reduced shape. All remaining sizes must
152 // be rank-reduced dims.
153 assert(isStaticUnitSize && "expected unit dim");
154 droppedDims.set(idx);
155 continue;
156 }
157
158 // Dim is preserved if the size is not a static 1.
159 if (!isStaticUnitSize) {
160 --shapePos;
161 continue;
162 }
163
164 // Dim is preserved if the reduced shape dim is also 1.
165 if (reducedShape[shapePos] == 1) {
166 --shapePos;
167 continue;
168 }
169
170 // Otherwise: Dim is dropped.
171 droppedDims.set(idx);
172 }
173
174 assert(shapePos < 0 && "dimension mismatch");
175 return droppedDims;
176}
177
178/// Given a ranked tensor type and a range of values that defines its dynamic
179/// dimension sizes, turn all dynamic sizes that have a constant value into
180/// static dimension sizes.
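/// Illustrative example (editorial sketch): for type tensor<?x?xf32> with
/// dynamicSizes = [%c5, %n], where %c5 is a constant 5, the returned type is
/// tensor<5x?xf32> and foldedDynamicSizes = [%n]; negative constants are kept
/// as dynamic sizes.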
181static RankedTensorType
182foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
183 SmallVector<Value> &foldedDynamicSizes) {
184 SmallVector<int64_t> staticShape(type.getShape());
185 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
186 "incorrect number of dynamic sizes");
187
188 // Compute new static and dynamic sizes.
189 unsigned ctr = 0;
190 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
191 if (type.isDynamicDim(i)) {
192 Value dynamicSize = dynamicSizes[ctr++];
193 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
194 if (cst.has_value()) {
195 // Dynamic size must be non-negative.
196 if (cst.value() < 0) {
197 foldedDynamicSizes.push_back(dynamicSize);
198 continue;
199 }
200 staticShape[i] = *cst;
201 } else {
202 foldedDynamicSizes.push_back(dynamicSize);
203 }
204 }
205 }
206
207 return RankedTensorType::get(staticShape, type.getElementType(),
208 type.getEncoding());
209}
210
211//===----------------------------------------------------------------------===//
212// BitcastOp
213//===----------------------------------------------------------------------===//
214
215bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
216 if (inputs.size() != 1 || outputs.size() != 1)
217 return false;
218 Type a = inputs.front(), b = outputs.front();
219 auto aT = dyn_cast<TensorType>(a);
220 auto bT = dyn_cast<TensorType>(b);
221 if (!aT || !bT)
222 return false;
223
224 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
225 return false;
226
227 return succeeded(verifyCompatibleShape(aT, bT));
228}
229
230namespace {
231
232/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
233/// operation.
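/// A minimal example of the rewrite (editorial sketch):
///
/// ```mlir
/// %0 = tensor.bitcast %arg : tensor<4xi32> to tensor<4xui32>
/// %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xf32>
/// ```
/// becomes
/// ```mlir
/// %1 = tensor.bitcast %arg : tensor<4xi32> to tensor<4xf32>
/// ```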
234struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
235 using OpRewritePattern<BitcastOp>::OpRewritePattern;
236
237 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
238 PatternRewriter &rewriter) const final {
239 auto tensorBitcastOperand =
240 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
241 if (!tensorBitcastOperand)
242 return failure();
243
244 auto resultType = cast<TensorType>(tensorBitcast.getType());
245 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
246 tensorBitcastOperand.getOperand());
247 return success();
248 }
249};
250
251} // namespace
252
253void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
254 MLIRContext *context) {
255 results.add<ChainedTensorBitcast>(context);
256}
257
258//===----------------------------------------------------------------------===//
259// CastOp
260//===----------------------------------------------------------------------===//
261
262void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
263 setNameFn(getResult(), "cast");
264}
265
266/// Returns true if `target` is a ranked tensor type that preserves static
267/// information available in the `source` ranked tensor type.
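/// For illustration (editorial note): preservesStaticInformation(
/// tensor<?x8xf32>, tensor<4x8xf32>) is true because every static extent of the
/// source is kept, while preservesStaticInformation(tensor<4x8xf32>,
/// tensor<?x8xf32>) is false since the target loses the static size 4.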
268bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
269 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
270 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
271
272 // Requires RankedTensorType.
273 if (!sourceType || !targetType)
274 return false;
275
276 // Requires same elemental type.
277 if (sourceType.getElementType() != targetType.getElementType())
278 return false;
279
280 // Requires same rank.
281 if (sourceType.getRank() != targetType.getRank())
282 return false;
283
284 // Requires same encoding.
285 if (sourceType.getEncoding() != targetType.getEncoding())
286 return false;
287
288 // If cast is towards more static sizes along any dimension, don't fold.
289 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
290 if (ShapedType::isStatic(std::get<0>(t)) &&
291 ShapedType::isDynamic(std::get<1>(t)))
292 return false;
293 }
294
295 return true;
296}
297
298/// Determines whether tensor::CastOp casts to a more dynamic version of the
299/// source tensor. This is useful to fold a tensor.cast into a consuming op and
300/// implement canonicalization patterns for ops in different dialects that may
301/// consume the results of tensor.cast operations. Such foldable tensor.cast
302/// operations are typically inserted as `slice` ops and are canonicalized,
303/// to preserve the type compatibility of their uses.
304///
305/// Returns true when all conditions are met:
306/// 1. source and result are ranked tensors with same element type and rank.
307/// 2. the source type has more static information than the result type.
308///
309/// Example:
310/// ```mlir
311/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
312/// %2 = consumer %1 ... : tensor<?x?xf32> ...
313/// ```
314///
315/// folds into:
316///
317/// ```mlir
318/// %2 = consumer %0 ... : tensor<8x16xf32> ...
319/// ```
320bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
321 if (!castOp)
322 return false;
323
324 // Can fold if the source of cast has at least as much static information as
325 // its results.
326 return preservesStaticInformation(castOp.getType(),
327 castOp.getSource().getType());
328}
329
330/// Determines whether the tensor::CastOp casts to a more static version of the
331/// source tensor. This is useful to fold into a producing op and implement
332/// canonicalization patterns with the `tensor.cast` op as the root, but
333/// producer being from different dialects. Returns true when all conditions are
334/// met:
335/// 1. source and result are ranked tensors with same element type and rank.
336/// 2. the result type has more static information than the source.
337///
338/// Example:
339/// ```mlir
340/// %1 = producer ... : tensor<?x?xf32>
341/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
342/// ```
343///
344/// can be canonicalized to :
345///
346/// ```mlir
347/// %2 = producer ... : tensor<8x16xf32>
348/// ```
349/// Not all ops might be canonicalizable this way, but for those that can be,
350/// this method provides a check that it is worth doing the canonicalization.
351bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
352 if (!castOp)
353 return false;
354 return preservesStaticInformation(castOp.getSource().getType(),
355 castOp.getType());
356}
357
358bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
359 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
360 if (llvm::isa<BlockArgument>(opOperand.get()))
361 return false;
362 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
363 return castOp && canFoldIntoConsumerOp(castOp);
364 });
365}
366
367SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
368 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
369 SmallVector<Value> newOperands;
370 newOperands.reserve(op->getNumOperands());
371
372 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
373
374 // Assumes that the result has dpsInits followed by nonDpsInits.
375 int64_t dpsInitIdx = 0;
376 for (OpOperand &opOperand : op->getOpOperands()) {
377 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
378 bool fold = canFoldIntoConsumerOp(tensorCastOp);
379 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
380 if (op.isDpsInit(&opOperand) &&
381 !llvm::isa<MemRefType>(newOperands.back().getType()))
382 newResTy[dpsInitIdx++] = newOperands.back().getType();
383 }
384 return newOperands;
385}
386
387/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
388/// that can be folded.
389LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
390 bool folded = false;
391 for (OpOperand &operand : op->getOpOperands()) {
392 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
393 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
394 operand.set(castOp.getOperand());
395 folded = true;
396 }
397 }
398 return success(folded);
399}
400
401bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
402 if (inputs.size() != 1 || outputs.size() != 1)
403 return false;
404 Type a = inputs.front(), b = outputs.front();
405 auto aT = llvm::dyn_cast<TensorType>(a);
406 auto bT = llvm::dyn_cast<TensorType>(b);
407 if (!aT || !bT)
408 return false;
409
410 if (aT.getElementType() != bT.getElementType())
411 return false;
412
413 return succeeded(verifyCompatibleShape(aT, bT));
414}
415
416/// Compute a TensorType that has the joined shape knowledge of the two
417/// given TensorTypes. The element types need to match.
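/// For example (editorial sketch): joining tensor<?x8xf32> with tensor<2x?xf32>
/// yields tensor<2x8xf32>; joining tensor<2x8xf32> with tensor<3x8xf32> has no
/// valid join and a null type is returned.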
418static TensorType joinShapes(TensorType one, TensorType two) {
419 assert(one.getElementType() == two.getElementType());
420
421 if (!one.hasRank())
422 return two;
423 if (!two.hasRank())
424 return one;
425
426 int64_t rank = one.getRank();
427 if (rank != two.getRank())
428 return {};
429
430 SmallVector<int64_t, 4> join;
431 join.reserve(rank);
432 for (int64_t i = 0; i < rank; ++i) {
433 if (one.isDynamicDim(i)) {
434 join.push_back(two.getDimSize(i));
435 continue;
436 }
437 if (two.isDynamicDim(i)) {
438 join.push_back(one.getDimSize(i));
439 continue;
440 }
441 if (one.getDimSize(i) != two.getDimSize(i))
442 return {};
443 join.push_back(one.getDimSize(i));
444 }
445 return RankedTensorType::get(join, one.getElementType());
446}
447
448namespace {
449
450/// Replaces chains of two tensor.cast operations by a single tensor.cast
451/// operation if doing so does not remove runtime constraints.
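/// A case that folds (editorial sketch):
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
/// becomes
/// ```mlir
/// %2 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x8xf32>
/// ```
/// because the fully dynamic intermediate type adds no runtime check beyond
/// what the source and result types already imply.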
452struct ChainedTensorCast : public OpRewritePattern<CastOp> {
453 using OpRewritePattern<CastOp>::OpRewritePattern;
454
455 LogicalResult matchAndRewrite(CastOp tensorCast,
456 PatternRewriter &rewriter) const final {
457 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
458
459 if (!tensorCastOperand)
460 return failure();
461
462 auto sourceType =
463 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
464 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
465 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
466
467 // We can remove the intermediate cast if joining all three produces the
468 // same result as just joining the source and result shapes.
469 auto firstJoin =
470 joinShapes(joinShapes(sourceType, intermediateType), resultType);
471
472 // The join might not exist if the cast sequence would fail at runtime.
473 if (!firstJoin)
474 return failure();
475
476 // The newJoin always exists if the above join exists, it might just contain
477 // less information. If so, we cannot drop the intermediate cast, as doing
478 // so would remove runtime checks.
479 auto newJoin = joinShapes(sourceType, resultType);
480 if (firstJoin != newJoin)
481 return failure();
482
483 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
484 tensorCastOperand.getOperand());
485 return success();
486 }
487};
488
489/// Fold tensor.cast into tensor.extract_slice producer.
490/// Example:
491/// ```
492/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
493/// tensor<128x512xf32> to tensor<?x512xf32>
494/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
495/// ```
496/// ->
497/// ```
498/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
499/// tensor<128x512xf32> to tensor<16x512xf32>
500/// ```
501struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
502 using OpRewritePattern<CastOp>::OpRewritePattern;
503
504 LogicalResult matchAndRewrite(CastOp tensorCast,
505 PatternRewriter &rewriter) const final {
506 auto extractOperand =
507 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
508
509 // Cannot fold cast to unranked tensor.
510 auto rankedResultType =
511 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
512 if (!rankedResultType)
513 return failure();
514
515 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
516 rankedResultType.getShape() ==
517 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
518 .getShape())
519 return failure();
520
521 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
522 auto dimMask = computeRankReductionMask(
523 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
524 size_t dimIndex = 0;
525 for (size_t i = 0, e = sizes.size(); i < e; i++) {
526 if (dimMask && dimMask->count(i))
527 continue;
528 int64_t dim = rankedResultType.getShape()[dimIndex++];
529 if (ShapedType::isDynamic(dim))
530 continue;
531 sizes[i] = rewriter.getIndexAttr(dim);
532 }
533
534 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
535 tensorCast, rankedResultType, extractOperand.getSource(),
536 extractOperand.getMixedOffsets(), sizes,
537 extractOperand.getMixedStrides());
538 return success();
539 }
540};
541
542} // namespace
543
544void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
545 MLIRContext *context) {
546 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
547}
548
549//===----------------------------------------------------------------------===//
550// ConcatOp
551//===----------------------------------------------------------------------===//
552
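/// Editorial note: inferResultType combines the per-dimension static
/// information of all inputs and adds up the sizes along the concatenation
/// dimension. For example (illustrative), concatenating tensor<3x?xf32> and
/// tensor<?x5xf32> along dim 0 infers tensor<?x5xf32>.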
553RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
554 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
555 auto tensorTypes =
556 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
557 int64_t concatRank = tensorTypes[0].getRank();
558
559 // The concatenation dim must be in the range [0, rank).
560 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
561
562 SmallVector<int64_t> sizes(concatRank);
563 for (int64_t i = 0, e = concatRank; i < e; ++i) {
564 if (i == dim)
565 continue;
566 SaturatedInteger size;
567 for (auto tensorType : tensorTypes)
568 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
569 sizes[i] = size.asInteger();
570 }
571 auto concatSize = SaturatedInteger::wrap(0);
572 for (auto tensorType : tensorTypes)
573 concatSize =
574 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
575 sizes[dim] = concatSize.asInteger();
576 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
577}
578
579void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
580 ValueRange inputs) {
581 FailureOr<RankedTensorType> resultType =
582 inferResultType(dim, inputs.getTypes());
583 assert(succeeded(resultType) && "failed to infer concatenation result type");
584 build(builder, result, *resultType, dim, inputs);
585}
586
587LogicalResult ConcatOp::verify() {
588 if (getInputs().size() < 1)
589 return emitOpError("requires at least one input");
590
591 SmallVector<RankedTensorType> inputTypes;
592 for (auto input : getInputs())
593 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
594
595 RankedTensorType resultType = getResultType();
596 int64_t resultRank = getRank();
597 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
598 return type.getRank() != resultRank;
599 }))
600 return emitOpError("rank of concatenated inputs must match result rank");
601
602 Type resultElementType = resultType.getElementType();
603 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
604 return type.getElementType() != resultElementType;
605 }))
606 return emitOpError("inputs and result element type must match");
607
608 int64_t dim = getDim();
609 if (dim >= resultRank)
610 return emitOpError("concatenation dim must be less than the tensor rank");
611
612 SmallVector<int64_t> sizes(resultRank);
613 for (int64_t i = 0, e = resultRank; i < e; ++i) {
614 if (i == dim)
615 continue;
616 SaturatedInteger size;
617 for (auto tensorType : inputTypes) {
618 FailureOr<SaturatedInteger> maybeSize =
619 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
620 if (failed(maybeSize))
621 return emitOpError("static concatenation size mismatch along ")
622 << "non-concatenated dimension " << i;
623 size = *maybeSize;
624 }
625 sizes[i] = size.asInteger();
626 }
627 auto concatSize = SaturatedInteger::wrap(0);
628 for (auto tensorType : inputTypes)
629 concatSize =
630 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
631 sizes[dim] = concatSize.asInteger();
632 auto inferredResultType =
633 RankedTensorType::get(sizes, inputTypes[0].getElementType());
634
635 for (auto [inferredSize, actualSize] :
636 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
637 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
638 ShapedType::isDynamic(actualSize);
639 if (!hasDynamic && inferredSize != actualSize)
640 return emitOpError("result type ")
641 << resultType << "does not match inferred shape "
642 << inferredResultType << " static sizes";
643 }
644
645 return success();
646}
647
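/// Editorial note: decomposeOperation lowers a concat into a tensor.empty of
/// the concatenated shape followed by one tensor.insert_slice per input, e.g.
/// (illustrative):
///
/// ```mlir
/// %r = tensor.concat dim(1) %a, %b : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
/// ```
/// is rewritten to
/// ```mlir
/// %e = tensor.empty() : tensor<2x7xf32>
/// %i0 = tensor.insert_slice %a into %e[0, 0] [2, 3] [1, 1] : tensor<2x3xf32> into tensor<2x7xf32>
/// %r = tensor.insert_slice %b into %i0[0, 3] [2, 4] [1, 1] : tensor<2x4xf32> into tensor<2x7xf32>
/// ```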
648FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
649 size_t numInputs = getInputs().size();
650 uint64_t concatDim = getDim();
651
652 SmallVector<SmallVector<OpFoldResult>> inputShapes;
653 inputShapes.reserve(numInputs);
654 SmallVector<OpFoldResult> concatOffsets;
655 concatOffsets.reserve(numInputs);
656 SmallVector<OpFoldResult> outputShape;
657
658 AffineExpr addExpr =
659 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
660 OpFoldResult zero = builder.getIndexAttr(0);
661 Location loc = getLoc();
662 for (auto [index, input] : llvm::enumerate(getInputs())) {
663 SmallVector<OpFoldResult> inputShape =
664 tensor::getMixedSizes(builder, input.getLoc(), input);
665 if (index == 0) {
666 outputShape = inputShape;
667 concatOffsets.push_back(zero);
668 } else {
669 concatOffsets.push_back(outputShape[concatDim]);
670 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
671 builder, loc, addExpr,
672 {outputShape[concatDim], inputShape[concatDim]});
673 }
674 inputShapes.emplace_back(std::move(inputShape));
675 }
676
677 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
678 getType().getElementType());
679
680 int64_t rank = getType().getRank();
681 OpFoldResult one = builder.getIndexAttr(1);
682 SmallVector<OpFoldResult> strides(rank, one);
683 SmallVector<OpFoldResult> offsets(rank, zero);
684 for (auto [index, input] : llvm::enumerate(getInputs())) {
685 offsets[concatDim] = concatOffsets[index];
686 auto insertSlice = tensor::InsertSliceOp::create(
687 builder, loc, input, replacement, offsets, inputShapes[index], strides);
688 replacement = insertSlice.getResult();
689 }
690 if (replacement.getType() != getType()) {
691 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
692 }
693 return SmallVector<Value>{replacement};
694}
695
696LogicalResult
697ConcatOp::reifyResultShapes(OpBuilder &builder,
698 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
699 ValueRange inputs = getInputs();
700 int64_t dim = getDim();
701 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
702
703 Value init = inputs[0];
704 int64_t rank = getType().getRank();
705
706 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
707
708 // Pre-populate the result sizes with as much static information as possible
709 // from the given result type, as well as the inferred result type, otherwise
710 // use the dim sizes from the first input.
711 for (int64_t i = 0; i < rank; ++i) {
712 if (i == dim)
713 continue;
714 if (!getType().isDynamicDim(i)) {
715 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
716 } else if (!inferredResultType.isDynamicDim(i)) {
717 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
718 builder, getLoc(),
719 builder.getIndexAttr(inferredResultType.getDimSize(i)));
720 } else {
721 reifiedReturnShapes[0][i] =
722 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
723 }
724 }
725
726 if (getType().isDynamicDim(dim)) {
727 // Take the sum of the input sizes along the concatenated dim.
728 AffineExpr sum = builder.getAffineDimExpr(0);
729 SmallVector<OpFoldResult> sizes = {
730 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
731 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
732 sum = sum + builder.getAffineDimExpr(idx + 1);
733 sizes.push_back(
734 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
735 }
736 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
737 builder, getLoc(),
738 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
739 } else {
740 // If the result shape is static along the concatenated dim, use the static
741 // shape.
742 reifiedReturnShapes[0][dim] =
743 builder.getIndexAttr(getType().getDimSize(dim));
744 }
745 return success();
746}
747
748void ConcatOp::getAsmResultNames(
749 function_ref<void(Value, StringRef)> setNameFn) {
750 setNameFn(getResult(), "concat");
751}
752
753OpFoldResult ConcatOp::fold(FoldAdaptor) {
754 ValueRange inputs = getInputs();
755 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
756 return inputs[0];
757 return {};
758}
759
760namespace {
761/// Fold a concat op with a single input to a cast.
762struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
763 using OpRewritePattern<ConcatOp>::OpRewritePattern;
764
765 LogicalResult matchAndRewrite(ConcatOp concatOp,
766 PatternRewriter &rewriter) const override {
767 if (concatOp.getInputs().size() != 1)
768 return failure();
769 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
770 concatOp.getInputs()[0]);
771 return success();
772 }
773};
774
775/// Propagate static shapes into the operands of a `tensor.concat`.
776///
777/// `tensor.concat` requires every operand to match on all dimensions except the
778/// concatenation dimension. If one operand is already static in those
779/// dimensions, the other operands may safely be refined to that same static
780/// shape.
781///
782/// Example:
783///
784/// ```mlir
785/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
786/// tensor<?x12xi32>
787/// ```
788/// ->
789/// ```mlir
790/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
791/// %2 = tensor.concat dim(0) %0, %cast :
792/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
793/// ```
794struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
795 using OpRewritePattern<ConcatOp>::OpRewritePattern;
796
797 LogicalResult matchAndRewrite(ConcatOp concatOp,
798 PatternRewriter &rewriter) const override {
799 int64_t dim = concatOp.getDim();
800 RankedTensorType inferredResultType =
801 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
802
803 // Find operands for which a more static shape can be inferred.
804 LogicalResult matched = failure();
805 // Inferred operand shapes are identical in every dimension except the
806 // concatenation dimension.
807 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
808 for (auto [operandIdx, operandType] :
809 llvm::enumerate(concatOp->getOperandTypes())) {
810 // Compute inferred type for operand.
811 inferredOperandShape[dim] =
812 cast<RankedTensorType>(operandType).getDimSize(dim);
813 auto inferredOperandType = RankedTensorType::get(
814 inferredOperandShape, inferredResultType.getElementType());
815
816 // Check if inferred type is more static.
817 if (!preservesStaticInformation(inferredOperandType, operandType)) {
818 matched = success();
819
820 // Use refined operand type and create cast from original operand.
821 auto castOp =
822 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
823 concatOp.getOperand(operandIdx));
824 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
825 concatOp->setOperand(operandIdx, castOp->getResult(0));
826 });
827 }
828 }
829
830 return matched;
831 }
832};
833
834/// Ensure `tensor.concat`'s result type is at least as static as can be
835/// inferred from its operand types.
836///
837/// Example:
838/// ```mlir
839/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
840/// tensor<?x?xi32>
841/// ```
842/// ->
843/// ```mlir
844/// %concat = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
845/// -> tensor<?x12xi32>
846/// %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
847/// ```
848struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
849 using OpRewritePattern<ConcatOp>::OpRewritePattern;
850
851 LogicalResult matchAndRewrite(ConcatOp concatOp,
852 PatternRewriter &rewriter) const override {
853 int64_t dim = concatOp.getDim();
854 RankedTensorType inferredResultType =
855 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
856
857 // The result type should be at least as static as inferred result type.
858 if (preservesStaticInformation(inferredResultType,
859 concatOp.getResultType())) {
860 return failure();
861 }
862
863 auto newConcatOp =
864 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
865 concatOp->getOperands());
866 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
867 newConcatOp);
868
869 return success();
870 }
871};
872} // namespace
873
874void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
875 MLIRContext *context) {
876 results
877 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
878 context);
879}
880
881//===----------------------------------------------------------------------===//
882// DimOp
883//===----------------------------------------------------------------------===//
884
885void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
886 setNameFn(getResult(), "dim");
887}
888
889void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
890 int64_t index) {
891 auto loc = result.location;
892 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
893 build(builder, result, source, indexValue);
894}
895
896std::optional<int64_t> DimOp::getConstantIndex() {
897 return getConstantIntValue(getIndex());
898}
899
900Speculation::Speculatability DimOp::getSpeculatability() {
901 auto constantIndex = getConstantIndex();
902 if (!constantIndex)
903 return Speculation::NotSpeculatable;
904
905 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
906 if (!rankedSourceType)
907 return Speculation::NotSpeculatable;
908
909 if (rankedSourceType.getRank() <= constantIndex)
910 return Speculation::NotSpeculatable;
911
912 return Speculation::Speculatable;
913}
914
915void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
916 SetIntLatticeFn setResultRange) {
917 setResultRange(getResult(),
918 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
919}
920
921OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
922 // All forms of folding require a known index.
923 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
924 if (!index)
925 return {};
926
927 // Folding for unranked types (UnrankedTensorType) is not supported.
928 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
929 if (!tensorType)
930 return {};
931
932 // Out of bound indices produce undefined behavior but are still valid IR.
933 // Don't choke on them.
934 int64_t indexVal = index.getInt();
935 if (indexVal < 0 || indexVal >= tensorType.getRank())
936 return {};
937
938 // Fold if the shape extent along the given index is known.
939 if (!tensorType.isDynamicDim(index.getInt())) {
940 Builder builder(getContext());
941 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
942 }
943
944 Operation *definingOp = getSource().getDefiningOp();
945
946 // Fold dim to the operand of tensor.generate.
947 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
948 auto resultType =
949 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
950 // The case where the type encodes the size of the dimension is handled
951 // above.
952 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
953
954 // Find the operand of the fromElements that corresponds to this index.
955 auto dynExtents = fromElements.getDynamicExtents().begin();
956 for (auto dim : resultType.getShape().take_front(index.getInt()))
957 if (ShapedType::isDynamic(dim))
958 dynExtents++;
959
960 return Value{*dynExtents};
961 }
962
963 // The size at the given index is now known to be a dynamic size.
964 unsigned unsignedIndex = index.getValue().getZExtValue();
965
966 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
967 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
968 // `resolve-shaped-type-result-dims` pass.
969 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
970 sliceOp.isDynamicSize(unsignedIndex)) {
971 return {sliceOp.getDynamicSize(unsignedIndex)};
972 }
973 }
974
975 // dim(cast) -> dim
976 if (succeeded(foldTensorCast(*this)))
977 return getResult();
978
979 return {};
980}
981
982namespace {
983/// Fold dim of a cast into the dim of the source of the tensor cast.
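/// For instance (editorial sketch), `tensor.dim (tensor.cast %t :
/// tensor<4x?xf32> to tensor<?x?xf32>), %c1` becomes `tensor.dim %t, %c1`.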
984struct DimOfCastOp : public OpRewritePattern<DimOp> {
985 using OpRewritePattern<DimOp>::OpRewritePattern;
986
987 LogicalResult matchAndRewrite(DimOp dimOp,
988 PatternRewriter &rewriter) const override {
989 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
990 if (!castOp)
991 return failure();
992 Value newSource = castOp.getOperand();
993 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
994 return success();
995 }
996};
997
998/// Fold dim of a destination passing style op into the dim of the corresponding
999/// init.
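/// For instance (editorial sketch), if %r is produced by a destination-style op
/// such as `linalg.fill ... outs(%init)`, then `tensor.dim %r, %c0` is rewritten
/// to `tensor.dim %init, %c0`, since the result always has the shape of its tied
/// init operand.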
1000struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1001 using OpRewritePattern<DimOp>::OpRewritePattern;
1002
1003 LogicalResult matchAndRewrite(DimOp dimOp,
1004 PatternRewriter &rewriter) const override {
1005 auto source = dimOp.getSource();
1006 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1007 if (!destOp)
1008 return failure();
1009
1010 auto resultIndex = cast<OpResult>(source).getResultNumber();
1011 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1012
1013 rewriter.modifyOpInPlace(
1014 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1015 return success();
1016 }
1017};
1018
1019/// Fold dim of a tensor reshape operation into an extract from the reshape's
1020/// shape operand.
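/// For instance (editorial sketch), with %r = tensor.reshape %src(%shape), the
/// op `tensor.dim %r, %idx` becomes `tensor.extract %shape[%idx]` (plus an
/// arith.index_cast when the shape tensor's element type is not `index`).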
1021struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1022 using OpRewritePattern<DimOp>::OpRewritePattern;
1023
1024 LogicalResult matchAndRewrite(DimOp dim,
1025 PatternRewriter &rewriter) const override {
1026 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1027
1028 if (!reshape)
1029 return failure();
1030
1031 // Since tensors are immutable we don't need to worry about where to place
1032 // the extract call
1033 rewriter.setInsertionPointAfter(dim);
1034 Location loc = dim.getLoc();
1035 Value extract =
1036 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1037 if (extract.getType() != dim.getType())
1038 extract =
1039 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1040 rewriter.replaceOp(dim, extract);
1041 return success();
1042 }
1043};
1044} // namespace
1045
1046void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1047 MLIRContext *context) {
1048 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// EmptyOp
1053//===----------------------------------------------------------------------===//
1054
1055void EmptyOp::build(OpBuilder &builder, OperationState &result,
1056 ArrayRef<int64_t> staticShape, Type elementType,
1057 Attribute encoding) {
1058 assert(none_of(staticShape, ShapedType::isDynamic) &&
1059 "expected only static sizes");
1060 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1061}
1062
1063void EmptyOp::build(OpBuilder &builder, OperationState &result,
1064 ArrayRef<int64_t> staticShape, Type elementType,
1065 ValueRange dynamicSizes, Attribute encoding) {
1066 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1067 build(builder, result, tensorType, dynamicSizes);
1068}
1069
1070void EmptyOp::build(OpBuilder &builder, OperationState &result,
1071 ArrayRef<OpFoldResult> sizes, Type elementType,
1072 Attribute encoding) {
1073 SmallVector<int64_t> staticShape;
1074 SmallVector<Value> dynamicSizes;
1075 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1076 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1077}
1078
1079LogicalResult EmptyOp::verify() {
1080 return verifyDynamicDimensionCount(getOperation(), getType(),
1081 getDynamicSizes());
1082}
1083
1084LogicalResult
1085EmptyOp::reifyResultShapes(OpBuilder &builder,
1086 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1087 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1088 unsigned ctr = 0;
1089 for (int64_t i = 0; i < getType().getRank(); ++i) {
1090 if (getType().isDynamicDim(i)) {
1091 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1092 } else {
1093 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1094 }
1095 }
1096 return success();
1097}
1098
1099Value EmptyOp::getDynamicSize(unsigned idx) {
1100 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1101 unsigned ctr = 0;
1102 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1103 if (getType().isDynamicDim(i))
1104 ++ctr;
1105 return getDynamicSizes()[ctr];
1106}
1107
1108SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1109 SmallVector<OpFoldResult> result;
1110 unsigned ctr = 0;
1111 Builder b(getContext());
1112 for (int64_t dim : getType().getShape()) {
1113 if (ShapedType::isDynamic(dim)) {
1114 result.push_back(getDynamicSizes()[ctr++]);
1115 } else {
1116 result.push_back(b.getIndexAttr(dim));
1117 }
1118 }
1119 return result;
1120}
1121
1122namespace {
1123/// Change the type of the result of a `tensor.empty` by making the result
1124/// type statically sized along dimensions that in the original operation were
1125/// defined as dynamic, but the size was defined using a `constant` op. For
1126/// example
1127///
1128/// %c5 = arith.constant 5: index
1129/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1130///
1131/// to
1132///
1133/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1134struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1135 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1136
1137 LogicalResult matchAndRewrite(EmptyOp op,
1138 PatternRewriter &rewriter) const override {
1139 SmallVector<Value> foldedDynamicSizes;
1140 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1141 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1142
1143 // Stop here if no dynamic size was promoted to static.
1144 if (foldedTensorType == op.getType())
1145 return failure();
1146
1147 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1148 foldedDynamicSizes);
1149 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1150 return success();
1151 }
1152};
1153
1154struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1155 using OpRewritePattern<DimOp>::OpRewritePattern;
1156
1157 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1158 PatternRewriter &rewriter) const override {
1159 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1160 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1161 if (!emptyTensorOp || !maybeConstantIndex)
1162 return failure();
1163 auto emptyTensorType = emptyTensorOp.getType();
1164 if (*maybeConstantIndex < 0 ||
1165 *maybeConstantIndex >= emptyTensorType.getRank() ||
1166 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1167 return failure();
1168 rewriter.replaceOp(dimOp,
1169 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1170 return success();
1171 }
1172};
1173
1174/// Canonicalize
1175///
1176/// ```mlir
1177/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1178/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1179/// ```
1180///
1181/// into
1182///
1183/// ```mlir
1184/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1185/// ```
1186///
1187/// This assumes the input program is correct in terms of its shape. So it is
1188/// safe to assume that `%d0` is in fact 4.
1189struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1190 using OpRewritePattern<CastOp>::OpRewritePattern;
1191
1192 LogicalResult matchAndRewrite(CastOp castOp,
1193 PatternRewriter &rewriter) const override {
1194 if (!canFoldIntoProducerOp(castOp))
1195 return failure();
1196 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1197 if (!producer)
1198 return failure();
1199
1200 auto resultType =
1201 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1202 ArrayRef<int64_t> resultShape = resultType.getShape();
1203 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1204 SmallVector<OpFoldResult> newMixedSizes;
1205 newMixedSizes.reserve(currMixedSizes.size());
1206 assert(resultShape.size() == currMixedSizes.size() &&
1207 "mismatch in result shape and sizes of empty op");
1208 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1209 int64_t newDim = std::get<0>(it);
1210 OpFoldResult currDim = std::get<1>(it);
1211 // Case 1: The empty tensor dim is static. Check that the tensor cast
1212 // result dim matches.
1213 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1214 if (ShapedType::isDynamic(newDim) ||
1215 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1216 // Something is off, the cast result shape cannot be more dynamic
1217 // than the empty tensor result shape (enforced by
1218 // `canFoldIntoProducer`). Abort for now.
1219 return rewriter.notifyMatchFailure(
1220 producer, "mismatch in static value of shape of empty tensor "
1221 "result and cast result");
1222 }
1223 newMixedSizes.push_back(attr);
1224 continue;
1225 }
1226
1227 // Case 2 : The tensor cast shape is static, but empty tensor result
1228 // shape is dynamic.
1229 if (ShapedType::isStatic(newDim)) {
1230 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1231 continue;
1232 }
1233
1234 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1235 // shape is dynamic. Use the dynamic value from the empty tensor op.
1236 newMixedSizes.push_back(currDim);
1237 }
1238
1239 // TODO: Do not drop tensor encoding.
1240 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1241 resultType.getElementType());
1242 return success();
1243 }
1244};
1245
1246} // namespace
1247
1248void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1249 MLIRContext *context) {
1250 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1251 ReplaceEmptyTensorStaticShapeDims>(context);
1252}
1253
1254//===----------------------------------------------------------------------===//
1255// ExtractOp
1256//===----------------------------------------------------------------------===//
1257
1258namespace {
1259
1260/// Canonicalizes the pattern of the form
1261///
1262/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1263/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1264///
1265/// to
1266///
1267/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1268struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1269 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1270
1271 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1272 PatternRewriter &rewriter) const final {
1273 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1274 if (!tensorCast)
1275 return failure();
1276 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1277 return failure();
1278 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1279 extract, tensorCast.getSource(), extract.getIndices());
1280 return success();
1281 }
1282};
1283
1284/// Canonicalizes the pattern of the form
1285///
1286/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1287/// tensor<12xf64>
1288/// %extracted_element = tensor.extract %val[%c10] :
1289/// tensor<12xf64>
1290///
1291/// to
1292///
1293/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1294struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1295 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1296
1297 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1298 PatternRewriter &rewriter) const final {
1299 auto collapseOp =
1300 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1301 if (!collapseOp)
1302 return failure();
1303 if (!collapseOp.getSrcType().hasStaticShape())
1304 return failure();
1305
1306 auto sourceSizes = collapseOp.getSrcType().getShape();
1307
1308 SmallVector<Value> indices(extractOp.getIndices().begin(),
1309 extractOp.getIndices().end());
1310 SmallVector<Value> sourceIndices;
1311 for (auto [index, group] :
1312 llvm::zip(indices, collapseOp.getReassociationIndices())) {
1313 assert(!group.empty() && "association indices groups cannot be empty");
1314 auto groupSize = group.size();
1315
1316 if (groupSize == 1) {
1317 sourceIndices.push_back(index);
1318 continue;
1319 }
1320
1321 SmallVector<int64_t> basis =
1322 llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1323 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1324 rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1325 llvm::append_range(sourceIndices, delinearize.getResults());
1326 }
1327 if (collapseOp.getReassociationIndices().empty()) {
1328 auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1329 int64_t srcRank =
1330 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1331 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1332 rewriter, extractOp.getLoc(), zeroAffineMap,
1333 ArrayRef<OpFoldResult>{});
1334 for (int64_t i = 0; i < srcRank; i++) {
1335 sourceIndices.push_back(
1336 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1337 }
1338 }
1339
1340 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1341 extractOp, collapseOp.getSrc(), sourceIndices);
1342 return success();
1343 }
1344};
1345
1346} // namespace
1347
1348void ExtractOp::getAsmResultNames(
1349 function_ref<void(Value, StringRef)> setNameFn) {
1350 setNameFn(getResult(), "extracted");
1351}
1352
1353LogicalResult ExtractOp::verify() {
1354 // Verify the # indices match if we have a ranked type.
1355 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1356 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1357 return emitOpError("incorrect number of indices for extract_element");
1358 return success();
1359}
1360
1361/// If we have an ExtractOp consuming an InsertOp with the same
1362/// indices, we can return the InsertOp's scalar directly.
1363// TODO: This only checks the immediate producer; extend to go up the
1364// insert/extract chain if the slices are disjoint.
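// Example of the fold (editorial sketch):
//   %t = tensor.insert %v into %dest[%i, %j] : tensor<4x4xf32>
//   %e = tensor.extract %t[%i, %j] : tensor<4x4xf32>
// folds %e to %v because the indices are identical.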
1365static Value foldExtractAfterInsert(ExtractOp extractOp) {
1366 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1367
1368 auto isSame = [](Value a, Value b) {
1369 return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1370 };
1371 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1372 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1373 return insertOp.getScalar();
1374
1375 return {};
1376}
1377
1378OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1379 if (Attribute tensor = adaptor.getTensor()) {
1380 // If this is a splat elements attribute, simply return the value.
1381 // All of the elements of a splat attribute are the same.
1382 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1383 return splatTensor.getSplatValue<Attribute>();
1384
1385 // If this is a dense resource elements attribute, return.
1386 if (isa<DenseResourceElementsAttr>(tensor))
1387 return {};
1388 }
1389
1390 // Collect the constant indices into the tensor.
1391 SmallVector<uint64_t, 8> indices;
1392 for (Attribute indice : adaptor.getIndices()) {
1393 if (!indice || !llvm::isa<IntegerAttr>(indice))
1394 return {};
1395 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1396 }
1397
1398 // Fold extract(from_elements(...)).
1399 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1400 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1401 auto rank = tensorType.getRank();
1402 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1403 "rank mismatch");
1404 int flatIndex = 0;
1405 int stride = 1;
1406 for (int i = rank - 1; i >= 0; --i) {
1407 flatIndex += indices[i] * stride;
1408 stride *= tensorType.getDimSize(i);
1409 }
1410 // Prevent out of bounds accesses. This can happen in invalid code that
1411 // will never execute.
1412 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1413 flatIndex < 0)
1414 return {};
1415 return fromElementsOp.getElements()[flatIndex];
1416 }
1417
1418 // If this is an elements attribute, query the value at the given indices.
1419 if (Attribute tensor = adaptor.getTensor()) {
1420 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1421 if (elementsAttr && elementsAttr.isValidIndex(indices))
1422 return elementsAttr.getValues<Attribute>()[indices];
1423 }
1424
1425 if (Value result = foldExtractAfterInsert(*this))
1426 return result;
1427
1428 return {};
1429}
1430
1431void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1432 MLIRContext *context) {
1433 results.add<ExtractFromTensorCast>(context);
1434}
1435
1436void mlir::tensor::populateFoldCollapseExtractPatterns(
1437 RewritePatternSet &patterns) {
1438 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1439}
1440
1441//===----------------------------------------------------------------------===//
1442// FromElementsOp
1443//===----------------------------------------------------------------------===//
1444
1445void FromElementsOp::getAsmResultNames(
1446 function_ref<void(Value, StringRef)> setNameFn) {
1447 setNameFn(getResult(), "from_elements");
1448}
1449
1450void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1451 ValueRange elements) {
1452 assert(!elements.empty() && "expected at least one element");
1453 Type resultType = RankedTensorType::get(
1454 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1455 build(builder, result, resultType, elements);
1456}
1457
1458OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1459 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1460 return DenseElementsAttr::get(getType(), adaptor.getElements());
1461 return {};
1462}
1463
1464namespace {
1465
1466// Pushes the index_casts that occur before extractions to after the extract.
1467// This minimizes type conversion in some cases and enables the extract
1468// canonicalizer. This changes:
1469//
1470// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1471// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1472//
1473// to the following:
1474//
1475// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1476// %cast = arith.index_cast %extract : i32 to index
1477//
1478// to just %element.
1479//
1480// Consider expanding this to a template and handle all tensor cast
1481// operations.
1482struct ExtractElementFromIndexCast
1483 : public OpRewritePattern<tensor::ExtractOp> {
1484 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1485
1486 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1487 PatternRewriter &rewriter) const final {
1488 Location loc = extract.getLoc();
1489 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1490 if (!indexCast)
1491 return failure();
1492
1493 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1494
1495 auto newExtract = tensor::ExtractOp::create(
1496 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1497
1498 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1499 newExtract);
1500
1501 return success();
1502 }
1503};
1504
1505} // namespace
1506
1507void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1508 MLIRContext *context) {
1509 results.add<ExtractElementFromIndexCast>(context);
1510}
1511
1512//===----------------------------------------------------------------------===//
1513// GatherOp
1514//===----------------------------------------------------------------------===//
1515
1516void GatherOp::getAsmResultNames(
1517 function_ref<void(Value, StringRef)> setNameFn) {
1518 setNameFn(getResult(), "gather");
1519}
1520
1521/// Return the inferred result type for a gatherOp where:
1522/// - sourceType is the type of the source tensor gathered from
1523/// - indicesType is the type of the indices used to gather
1524/// - gatherDims are the dims along which the gather occurs.
1525/// Return a full rank or rank-reduced variant of the type depending on
1526/// the value of rankReduced.
1527///
1528/// The leading dimensions of the index tensor give the result tensor its
1529/// leading dimensions.
1530/// The trailing dimensions of the result tensor are obtained from the source
1531/// tensor by setting the dimensions specified in gather_dims to `1` (if
1532/// rankReduced is false), or skipping them (otherwise).
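/// Illustrative example (editorial sketch): with source tensor<4x5x6xf32>,
/// indices tensor<2x3x1xindex>, and gather_dims = [1], the inferred result is
/// tensor<2x3x4x1x6xf32>, or tensor<2x3x4x6xf32> in the rank-reduced variant.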
1533RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1534 RankedTensorType indicesType,
1535 ArrayRef<int64_t> gatherDims,
1536 bool rankReduced) {
1537 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1538 resultShape.reserve(resultShape.size() + sourceType.getRank());
1539 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1540 if (llvm::binary_search(gatherDims, idx)) {
1541 if (!rankReduced)
1542 resultShape.push_back(1);
1543 continue;
1544 }
1545 resultShape.push_back(sourceType.getDimSize(idx));
1546 }
1547 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1548}
1549
1550static LogicalResult
1551verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1552 ArrayRef<int64_t> indices, int64_t rank,
1553 StringRef gatherOrScatter, StringRef sourceOrDest) {
1554 if (dims.empty())
1555 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1556
1557 int64_t numGatherDims = dims.size();
1558 if (numGatherDims > rank)
1559 return op->emitOpError(gatherOrScatter)
1560 << "_dims overflow " << sourceOrDest << " rank";
1561 if (indices.empty() || indices.back() != numGatherDims)
1562 return op->emitOpError(gatherOrScatter)
1563 << "_dims length must match the size of last dimension of indices";
1564 for (int64_t val : dims) {
1565 if (val < 0)
1566 return op->emitOpError(gatherOrScatter)
1567 << "_dims value must be non-negative";
1568 if (val >= rank)
1569 return op->emitOpError(gatherOrScatter)
1570 << "_dims value must be smaller than " << sourceOrDest << " rank";
1571 }
1572 for (int64_t i = 1; i < numGatherDims; ++i) {
1573 if (dims[i - 1] >= dims[i])
1574 return op->emitOpError(gatherOrScatter)
1575 << "_dims values must be strictly increasing";
1576 }
1577 return success();
1578}
1579
1580LogicalResult GatherOp::verify() {
1581 int64_t sourceRank = getSourceType().getRank();
1582 ArrayRef<int64_t> gatherDims = getGatherDims();
1583 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1584 getIndicesType().getShape(), sourceRank,
1585 "gather", "source")))
1586 return failure();
1587
1588 RankedTensorType expectedResultType = GatherOp::inferResultType(
1589 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1590 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1591 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1592 if (getResultType() != expectedResultType &&
1593 getResultType() != expectedRankReducedResultType) {
1594 return emitOpError("result type "
1595 "mismatch: "
1596 "expected ")
1597 << expectedResultType << " or its rank-reduced variant "
1598 << expectedRankReducedResultType << " (got: " << getResultType()
1599 << ")";
1600 }
1601
1602 return success();
1603}
1604
1605OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1606 if (OpFoldResult reshapedSource = reshapeConstantSource(
1607 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1608 getResult().getType()))
1609 return reshapedSource;
1610 return {};
1611}
1612
1613//===----------------------------------------------------------------------===//
1614// InsertOp
1615//===----------------------------------------------------------------------===//
1616
1617void InsertOp::getAsmResultNames(
1618 function_ref<void(Value, StringRef)> setNameFn) {
1619 setNameFn(getResult(), "inserted");
1620}
1621
1622LogicalResult InsertOp::verify() {
1623 // Verify the # indices match if we have a ranked type.
1624 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1625 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1626 return emitOpError("incorrect number of indices");
1627 return success();
1628}
1629
1630OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1631 Attribute scalar = adaptor.getScalar();
1632 Attribute dest = adaptor.getDest();
1633 if (scalar && dest)
1634 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1635 if (scalar == splatDest.getSplatValue<Attribute>())
1636 return dest;
1637 return {};
1638}
1639
1640//===----------------------------------------------------------------------===//
1641// GenerateOp
1642//===----------------------------------------------------------------------===//
1643
1644void GenerateOp::getAsmResultNames(
1645 function_ref<void(Value, StringRef)> setNameFn) {
1646 setNameFn(getResult(), "generated");
1647}
1648
1649LogicalResult GenerateOp::reifyResultShapes(
1650 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1651 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1652 int idx = 0;
1653 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1654 if (getType().isDynamicDim(dim)) {
1655 reifiedReturnShapes[0][dim] = getOperand(idx++);
1656 } else {
1657 reifiedReturnShapes[0][dim] =
1658 builder.getIndexAttr(getType().getDimSize(dim));
1659 }
1660 }
1661 return success();
1662}
1663
1664LogicalResult GenerateOp::verify() {
1665 // Ensure that the tensor type has as many dynamic dimensions as are
1666 // specified by the operands.
1667 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1668 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
1669 getOperands())))
1670 return failure();
1671 return success();
1672}
1673
1674LogicalResult GenerateOp::verifyRegions() {
1675 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1676 // Ensure that region arguments span the index space.
1677 if (!llvm::all_of(getBody().getArgumentTypes(),
1678 [](Type ty) { return ty.isIndex(); }))
1679 return emitError("all body arguments must be index");
1680 if (getBody().getNumArguments() != resultTy.getRank())
1681 return emitError("must have one body argument per input dimension");
1682
1683 // Ensure that the region yields an element of the right type.
1684 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1685
1686 if (yieldOp.getValue().getType() != resultTy.getElementType())
1687 return emitOpError(
1688 "body must be terminated with a `yield` operation of the tensor "
1689 "element type");
1690
1691 return success();
1692}
1693
1694void GenerateOp::build(
1695 OpBuilder &b, OperationState &result, Type resultTy,
1696 ValueRange dynamicExtents,
1697 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1698 build(b, result, resultTy, dynamicExtents);
1699
1700 // Build and populate body.
1701 OpBuilder::InsertionGuard guard(b);
1702 Region *bodyRegion = result.regions.front().get();
1703 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1704 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1705 SmallVector<Location, 2> argumentLocs(rank, result.location);
1706 Block *bodyBlock =
1707 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1708 bodyBuilder(b, result.location, bodyBlock->getArguments());
1709}
1710
1711namespace {
1712
1713/// Canonicalizes tensor.generate operations whose dynamic extent operands
1714/// are constants into the equivalent operation with those extents folded
1715/// into the result type. A tensor.cast back to the original type is
1716/// inserted so that the resulting IR remains well-typed.
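///
/// Illustrative example (value names are placeholders):
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = tensor.generate %c8 {
/// ^bb0(%i : index):
///   ...
///   tensor.yield %v : f32
/// } : tensor<?xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
/// %gen = tensor.generate {
/// ^bb0(%i : index):
///   ...
///   tensor.yield %v : f32
/// } : tensor<8xf32>
/// %0 = tensor.cast %gen : tensor<8xf32> to tensor<?xf32>
/// ```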
1717struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1718 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1719
1720 LogicalResult matchAndRewrite(GenerateOp generateOp,
1721 PatternRewriter &rewriter) const final {
1722 SmallVector<Value> foldedDynamicSizes;
1723 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1724 generateOp.getType(), generateOp.getDynamicExtents(),
1725 foldedDynamicSizes);
1726
1727 // Stop here if no dynamic size was promoted to static.
1728 if (foldedTensorType == generateOp.getType())
1729 return failure();
1730
1731 auto loc = generateOp.getLoc();
1732 auto newOp =
1733 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1734 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1735 newOp.getBody().begin());
1736 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1737 generateOp.getType(), newOp);
1738 return success();
1739 }
1740};
1741
1742/// Canonicalizes the pattern of the form
1743///
1744/// %tensor = tensor.generate %x {
1745/// ^bb0(%arg0: index):
1746/// <computation>
1747/// yield %1 : index
1748/// } : tensor<?xindex>
1749/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1750///
1751/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1752/// tensor.generate operation has no side-effects.
1753struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1754 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1755
1756 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1757 PatternRewriter &rewriter) const final {
1758 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1759 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1760 return failure();
1761
1762 IRMapping mapping;
1763 Block *body = &tensorFromElements.getBody().front();
1764 mapping.map(body->getArguments(), extract.getIndices());
1765 for (auto &op : body->without_terminator())
1766 rewriter.clone(op, mapping);
1767
1768 auto yield = cast<YieldOp>(body->getTerminator());
1769
1770 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1771 return success();
1772 }
1773};
1774
1775} // namespace
1776
1777void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1778 MLIRContext *context) {
1779 // TODO: Move extract pattern to tensor::ExtractOp.
1780 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1781}
1782
1783//===----------------------------------------------------------------------===//
1784// RankOp
1785//===----------------------------------------------------------------------===//
1786
1787void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1788 setNameFn(getResult(), "rank");
1789}
1790
1791OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1792 // Constant fold rank when the rank of the operand is known.
1793 auto type = getOperand().getType();
1794 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1795 if (shapedType && shapedType.hasRank())
1796 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1797 return IntegerAttr();
1798}
1799
1800//===----------------------------------------------------------------------===//
1801// ReshapeOp
1802//===----------------------------------------------------------------------===//
1803
1804void ReshapeOp::getAsmResultNames(
1805 function_ref<void(Value, StringRef)> setNameFn) {
1806 setNameFn(getResult(), "reshape");
1807}
1808
1809static int64_t getNumElements(ShapedType type) {
1810 int64_t numElements = 1;
1811 for (auto dim : type.getShape())
1812 numElements *= dim;
1813 return numElements;
1814}
1815
1816LogicalResult ReshapeOp::verify() {
1817 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1818 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1819
1820 if (operandType.getElementType() != resultType.getElementType())
1821 return emitOpError("element types of source and destination tensor "
1822 "types should be the same");
1823
1824 int64_t shapeSize =
1825 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1826 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1827 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1828
1829 if (resultRankedType) {
1830 if (operandRankedType && resultRankedType.hasStaticShape() &&
1831 operandRankedType.hasStaticShape()) {
1832 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1833 return emitOpError("source and destination tensor should have the "
1834 "same number of elements");
1835 }
1836 if (ShapedType::isDynamic(shapeSize))
1837 return emitOpError("cannot use shape operand with dynamic length to "
1838 "reshape to statically-ranked tensor type");
1839 if (shapeSize != resultRankedType.getRank())
1840 return emitOpError(
1841 "length of shape operand differs from the result's tensor rank");
1842 }
1843 return success();
1844}
1845
1846OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1847 if (OpFoldResult reshapedSource = reshapeConstantSource(
1848 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1849 getResult().getType()))
1850 return reshapedSource;
1851
1852 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1853 // producer's input instead as the original tensor to reshape. This could
1854 // render such producer dead code.
1855 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1856 getSourceMutable().assign(reshapeOpProducer.getSource());
1857 return getResult();
1858 }
1859
1860 auto source = getSource();
1861 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1862 auto resultTy = dyn_cast<RankedTensorType>(getType());
1863 if (!sourceTy || !resultTy || sourceTy != resultTy)
1864 return {};
1865
1866 // If the source and result are both 0D or 1D tensors and have the same type,
1867 // the reshape has no effect, even if the tensor is dynamically shaped.
1868 if (sourceTy.getRank() <= 1)
1869 return source;
1870
1871 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1872 auto elements = fromElements.getElements();
1873 bool dynamicNoop =
1874 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1875 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1876 auto element = elements[id];
1877
1878 if (auto cst = getConstantIntValue(element)) {
1879 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1880 continue;
1881 }
1882
1883 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1884 dynamicNoop &= dimOp.getSource() == source;
1885
1886 auto cst = getConstantIntValue(dimOp.getIndex());
1887 dynamicNoop &=
1888 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1889 continue;
1890 }
1891
1892 dynamicNoop = false;
1893 break;
1894 }
1895
1896 if (dynamicNoop)
1897 return source;
1898 }
1899
1900 return {};
1901}
1902
1903//===----------------------------------------------------------------------===//
1904// Reassociative reshape ops
1905//===----------------------------------------------------------------------===//
1906
1907void CollapseShapeOp::getAsmResultNames(
1908 function_ref<void(Value, StringRef)> setNameFn) {
1909 setNameFn(getResult(), "collapsed");
1910}
1911
1912void ExpandShapeOp::getAsmResultNames(
1913 function_ref<void(Value, StringRef)> setNameFn) {
1914 setNameFn(getResult(), "expanded");
1915}
1916
1917int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1918 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1919 "invalid resultDim");
1920 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1921 if (llvm::is_contained(it.value(), resultDim))
1922 return it.index();
1923 llvm_unreachable("could not find reassociation group");
1924}
1925
1926FailureOr<SmallVector<OpFoldResult>>
1927ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1928 RankedTensorType expandedType,
1929 ArrayRef<ReassociationIndices> reassociation,
1930 ArrayRef<OpFoldResult> inputShape) {
1931 std::optional<SmallVector<OpFoldResult>> outputShape =
1932 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1933 inputShape);
1934 if (!outputShape)
1935 return failure();
1936 return *outputShape;
1937}
1938
1939SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1940 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1941}
1942
1943void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1944 Type resultType, Value src,
1945 ArrayRef<ReassociationIndices> reassociation,
1946 ArrayRef<OpFoldResult> outputShape) {
1947 auto [staticOutputShape, dynamicOutputShape] =
1948 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1949 build(builder, result, cast<RankedTensorType>(resultType), src,
1950 getReassociationIndicesAttribute(builder, reassociation),
1951 dynamicOutputShape, staticOutputShape);
1952}
1953
1954void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1955 Type resultType, Value src,
1956 ArrayRef<ReassociationIndices> reassociation) {
1957 SmallVector<OpFoldResult> inputShape =
1958 getMixedSizes(builder, result.location, src);
1959 auto tensorResultTy = cast<RankedTensorType>(resultType);
1960 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1961 builder, result.location, tensorResultTy, reassociation, inputShape);
1962 SmallVector<OpFoldResult> outputShapeOrEmpty;
1963 if (succeeded(outputShape)) {
1964 outputShapeOrEmpty = *outputShape;
1965 }
1966 build(builder, result, tensorResultTy, src, reassociation,
1967 outputShapeOrEmpty);
1968}
1969
1970SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1971 return getSymbolLessAffineMaps(getReassociationExprs());
1972}
1973SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1974  return convertReassociationIndicesToExprs(getContext(),
1975                                            getReassociationIndices());
1976}
1977
1978SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1979 return getSymbolLessAffineMaps(getReassociationExprs());
1980}
1981SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1982  return convertReassociationIndicesToExprs(getContext(),
1983                                            getReassociationIndices());
1984}
1985
1986RankedTensorType CollapseShapeOp::inferCollapsedType(
1987 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1988 return inferCollapsedType(
1989      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1990                type.getContext(), reassociation)));
1991}
1992
1993/// Compute the RankedTensorType obtained by applying `reassociation` to
1994/// `type`.
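///
/// For example (illustrative): applying reassociation [[0, 1], [2, 3]] to
/// tensor<2x3x?x5xf32> collapses the first group to 2 * 3 = 6 and the second
/// group, which contains a dynamic size, to a dynamic size, yielding
/// tensor<6x?xf32>.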
1995RankedTensorType
1996CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1997 ArrayRef<AffineMap> reassociation) {
1998 auto shape = type.getShape();
1999 SmallVector<int64_t, 4> newShape;
2000 newShape.reserve(reassociation.size());
2001
2002 // Use the fact that reassociation is valid to simplify the logic: only use
2003 // each map's rank.
2004 assert(isReassociationValid(reassociation) && "invalid reassociation");
2005 unsigned currentDim = 0;
2006 for (AffineMap m : reassociation) {
2007 unsigned dim = m.getNumResults();
2008 auto band = shape.slice(currentDim, dim);
2009 int64_t size = 1;
2010 if (llvm::is_contained(band, ShapedType::kDynamic))
2011 size = ShapedType::kDynamic;
2012 else
2013 for (unsigned d = 0; d < dim; ++d)
2014 size *= shape[currentDim + d];
2015 newShape.push_back(size);
2016 currentDim += dim;
2017 }
2018
2019 return RankedTensorType::get(newShape, type.getElementType());
2020}
2021
2022void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2023 ArrayRef<ReassociationIndices> reassociation,
2024 ArrayRef<NamedAttribute> attrs) {
2025 auto srcType = llvm::cast<RankedTensorType>(src.getType());
2026 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2027 auto resultType =
2028 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2029 srcType.getEncoding());
2030 result.addAttribute(getReassociationAttrStrName(),
2031 getReassociationIndicesAttribute(b, reassociation));
2032 build(b, result, resultType, src, attrs);
2033}
2034
2035template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2036 TensorReshapeOp, ExpandShapeOp>::value>
2037static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2038 RankedTensorType expandedType,
2039 RankedTensorType collapsedType) {
2040 if (failed(
2041 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2042 return failure();
2043
2044 // Reshape must preserve the number of elements when statically known.
2045 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
2046 int64_t expandedNumElements = expandedType.getNumElements();
2047 int64_t collapsedNumElements = collapsedType.getNumElements();
2048 if (expandedNumElements != collapsedNumElements) {
2049 return op.emitOpError("number of elements must be preserved: ")
2050 << expandedNumElements << " != " << collapsedNumElements;
2051 }
2052 }
2053
2054 auto maps = op.getReassociationMaps();
2055 RankedTensorType expectedType =
2056 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2057 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2058 return op.emitOpError("expected collapsed type to be ")
2059 << expectedType << ", but got " << collapsedType;
2060 return success();
2061}
2062
2063LogicalResult ExpandShapeOp::verify() {
2064 RankedTensorType srcType = getSrc().getType();
2065 RankedTensorType resultType = getResult().getType();
2066
2067 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2068 return emitOpError("expected number of static shape dims to be equal to "
2069 "the output rank (")
2070 << resultType.getRank() << ") but found "
2071 << getStaticOutputShape().size() << " inputs instead";
2072
2073 if ((int64_t)getOutputShape().size() !=
2074 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2075 return emitOpError("mismatch in dynamic dims in output_shape and "
2076 "static_output_shape: static_output_shape has ")
2077 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2078 << " dynamic dims while output_shape has " << getOutputShape().size()
2079 << " values";
2080
2081 return verifyTensorReshapeOp(*this, resultType, srcType);
2082}
2083
2084LogicalResult CollapseShapeOp::verify() {
2085 CollapseShapeOp op = *this;
2086 if (llvm::any_of(op.getReassociationIndices(),
2087 [](ReassociationIndices group) { return group.empty(); })) {
2088 return op.emitOpError("reassociation indices must not be empty");
2089 }
2090 RankedTensorType srcType = op.getSrc().getType();
2091 RankedTensorType resultType = op.getResult().getType();
2092
2093 return verifyTensorReshapeOp(op, srcType, resultType);
2094}
2095
2096namespace {
2097/// Reshape of a splat constant can be replaced with a constant of the result
2098/// type.
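///
/// Illustrative example (only splat constants are handled by this pattern):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<4x2xf32>
/// %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<4x2xf32> into tensor<8xf32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %0 = arith.constant dense<1.0> : tensor<8xf32>
/// ```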
2099template <typename TensorReshapeOp>
2100struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2101 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2102 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2103 PatternRewriter &rewriter) const override {
2104 DenseElementsAttr attr;
2105 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2106 return failure();
2107 if (!attr || !attr.isSplat())
2108 return failure();
2109 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2110 reshapeOp.getResultType(), attr.getRawData());
2111 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2112 return success();
2113 }
2114};
2115
2116// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
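//
// Illustrative example (requires the splat result type to be static):
//
//   %0 = tensor.splat %v : tensor<2x4xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
//
// folds into:
//
//   %1 = tensor.splat %v : tensor<8xf32>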
2117template <typename TensorReshapeOp>
2118class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2119public:
2120 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2121
2122 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2123 PatternRewriter &rewriter) const override {
2124 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2125 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2126 return failure();
2127
2128 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2129 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2130 return success();
2131 }
2132};
2133
2134/// Reshape of a FromElements can be replaced with a FromElements of the
2135/// result type
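///
/// Illustrative example (requires a static result shape):
///
/// ```mlir
/// %0 = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 2]
///        : tensor<4xf32> into tensor<2x2xf32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %1 = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
/// ```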
2136template <typename TensorReshapeOp>
2137struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2138 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2139 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2140 PatternRewriter &rewriter) const override {
2141 auto fromElements =
2142 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2143 if (!fromElements)
2144 return failure();
2145
2146 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2147
2148 if (!shapedTy.hasStaticShape())
2149 return failure();
2150
2151 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2152 fromElements.getElements());
2153 return success();
2154 }
2155};
2156
2157// Fold CastOp into CollapseShapeOp when adding static information.
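//
// Illustrative example (the cast adds static sizes, so the collapse is
// recreated on the cast source and a new cast restores the original type):
//
//   %0 = tensor.cast %arg : tensor<4x2xf32> to tensor<?x?xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
//
// becomes:
//
//   %collapsed = tensor.collapse_shape %arg [[0, 1]]
//                  : tensor<4x2xf32> into tensor<8xf32>
//   %1 = tensor.cast %collapsed : tensor<8xf32> to tensor<?xf32>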
2158struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2159 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2160
2161 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2162 PatternRewriter &rewriter) const override {
2163 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2164 if (!tensor::canFoldIntoConsumerOp(castOp))
2165 return failure();
2166
2167 RankedTensorType srcType =
2168 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2169 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2170 srcType, collapseShapeOp.getReassociationMaps());
2171
2172 if (newResultType == collapseShapeOp.getResultType()) {
2173 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2174 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2175 });
2176 } else {
2177 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2178 newResultType, castOp.getSource(),
2179 collapseShapeOp.getReassociation());
2180 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2181 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2182 }
2183 return success();
2184 }
2185};
2186
2187/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2188/// matching constant output_shape operands of the expand. This makes the
2189/// `tensor.expand_shape` more static and creates a consumer cast that can be
2190/// propagated further.
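///
/// Illustrative example (%c2 is a constant index feeding output_shape):
///
/// ```mlir
/// %0 = tensor.cast %arg : tensor<6xf32> to tensor<?xf32>
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%c2, 3]
///        : tensor<?xf32> into tensor<?x3xf32>
/// ```
///
/// becomes:
///
/// ```mlir
/// %cast = tensor.cast %0 : tensor<?xf32> to tensor<6xf32>
/// %expanded = tensor.expand_shape %cast [[0, 1]] output_shape [2, 3]
///               : tensor<6xf32> into tensor<2x3xf32>
/// %1 = tensor.cast %expanded : tensor<2x3xf32> to tensor<?x3xf32>
/// ```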
2191struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2192 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2193
2194 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2195 PatternRewriter &rewriter) const override {
2196 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2197 if (!canFoldIntoConsumerOp(castOp))
2198 return failure();
2199
2200 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2201 SmallVector<ReassociationIndices, 4> reassoc =
2202 expandOp.getReassociationIndices();
2203
2204 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2205 SmallVector<Value> dynamicOutputShape;
2206 auto outputIt = expandOp.getOutputShape().begin();
2207
2208 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2209 for (uint64_t outDim : innerReassoc) {
2210 if (ShapedType::isStatic(newOutputShape[outDim]))
2211 continue;
2212
2213 // If the cast's src type is dynamic, don't infer any of the
2214 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2215 // least one of the expanded dimensions to be dynamic if the input is
2216 // dynamic.
2217 Value val = *outputIt;
2218 ++outputIt;
2219 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2220 dynamicOutputShape.push_back(val);
2221 continue;
2222 }
2223
2224 APInt cst;
2225 if (matchPattern(val, m_ConstantInt(&cst))) {
2226 newOutputShape[outDim] = cst.getSExtValue();
2227 } else {
2228 dynamicOutputShape.push_back(val);
2229 }
2230 }
2231 }
2232
2233 // Couldn't match any values, nothing to change
2234 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2235 return failure();
2236
2237 // Calculate the input shape from the output
2238 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2239 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2240 for (auto outDim : reassoc[inDim]) {
2241 auto ofr = newOutputShape[outDim];
2242 if (ShapedType::isDynamic(ofr)) {
2243 newInputShape[inDim] = ShapedType::kDynamic;
2244 break;
2245 }
2246 newInputShape[inDim] *= ofr;
2247 }
2248 }
2249
2250 SmallVector<OpFoldResult> outputOfr =
2251 getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2252 auto inputType = RankedTensorType::get(
2253 newInputShape, expandOp.getSrcType().getElementType());
2254 auto outputType = RankedTensorType::get(
2255 newOutputShape, expandOp.getSrcType().getElementType());
2256 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2257 expandOp.getSrc());
2258 auto newExpand = ExpandShapeOp::create(
2259 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2260 expandOp.getReassociationIndices(), outputOfr);
2261 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2262 newExpand.getResult());
2263 return success();
2264 }
2265};
2266} // namespace
2267
2268void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2269 MLIRContext *context) {
2270 results.add<
2271 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2272 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2273 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2274 FoldReshapeWithSplat<ExpandShapeOp>,
2275 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2276}
2277
2278void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2279 MLIRContext *context) {
2280 results.add<
2281 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2282 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2283 tensor::DimOp, RankedTensorType>,
2284 FoldReshapeWithConstant<CollapseShapeOp>,
2285 FoldReshapeWithSplat<CollapseShapeOp>,
2286 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2287 context);
2288}
2289
2290OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2291  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2292                                                       adaptor.getOperands());
2293}
2294
2295OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2296  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2297                                                       adaptor.getOperands());
2298}
2299
2300//===----------------------------------------------------------------------===//
2301// ExtractSliceOp
2302//===----------------------------------------------------------------------===//
2303
2304void ExtractSliceOp::getAsmResultNames(
2305 function_ref<void(Value, StringRef)> setNameFn) {
2306 setNameFn(getResult(), "extracted_slice");
2307}
2308
2309/// An extract_slice result type can be inferred, when it is not
2310/// rank-reduced, from the source type and the static representation of
2311/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2312RankedTensorType
2313ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2314 ArrayRef<int64_t> staticSizes) {
2315  // An extract_slice op may specify only a leading subset of offsets/sizes/
2316  // strides, in which case the rest is completed with offset=0, sizes from
2317  // the source tensor type, and strides=1.
2318 assert(static_cast<int64_t>(staticSizes.size()) ==
2319 sourceTensorType.getRank() &&
2320 "unexpected staticSizes not equal to rank of source");
2321 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2322 sourceTensorType.getEncoding());
2323}
2324
2325// TODO: This uses neither offsets nor strides!
2326RankedTensorType
2327ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2328 ArrayRef<OpFoldResult> sizes) {
2329 SmallVector<int64_t> staticSizes;
2330 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2331
2332 assert(static_cast<int64_t>(staticSizes.size()) ==
2333 sourceTensorType.getRank() &&
2334 "unexpected staticSizes not equal to rank of source");
2335 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2336 sourceTensorType.getEncoding());
2337}
2338
2339/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2340/// number of sizes), drop as many size 1 as needed to produce an inferred
2341/// type with the desired rank.
2342///
2343/// Note that there may be multiple ways to compute this rank-reduced type:
2344/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2345///
2346/// To disambiguate, this function always drops the first 1 sizes occurrences.
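///
/// For example, with sizes [1, 6, 1] and desiredResultRank = 2, the leading
/// unit dimension is dropped and the inferred shape is 6x1 rather than 1x6.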
2347RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2348 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2349 ArrayRef<int64_t> sizes) {
2350 // Type inferred in the absence of rank-reducing behavior.
2351 auto inferredType = llvm::cast<RankedTensorType>(
2352 inferResultType(sourceRankedTensorType, sizes));
2353 int rankDiff = inferredType.getRank() - desiredResultRank;
2354 if (rankDiff > 0) {
2355 auto shape = inferredType.getShape();
2356 llvm::SmallBitVector dimsToProject =
2357 getPositionsOfShapeOne(rankDiff, shape);
2358 SmallVector<int64_t> projectedShape;
2359 // Best effort rank-reducing: drop 1s in order.
2360 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2361 if (!dimsToProject.test(pos))
2362 projectedShape.push_back(shape[pos]);
2363 inferredType =
2364 RankedTensorType::get(projectedShape, inferredType.getElementType());
2365 }
2366 return inferredType;
2367}
2368
2369RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2370 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2371 ArrayRef<OpFoldResult> sizes) {
2372 SmallVector<int64_t> staticSizes;
2373 SmallVector<Value> dynamicSizes;
2374 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2375 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2376 desiredResultRank, sourceRankedTensorType, staticSizes);
2377}
2378
2379/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2380/// result type. If the type passed is nullptr, it is inferred.
2381void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2382 RankedTensorType resultType, Value source,
2383 ArrayRef<OpFoldResult> offsets,
2384 ArrayRef<OpFoldResult> sizes,
2385 ArrayRef<OpFoldResult> strides,
2386 ArrayRef<NamedAttribute> attrs) {
2387 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2388 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2389 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2390 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2391 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2392 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2393 // Structuring implementation this way avoids duplication between builders.
2394 if (!resultType) {
2395 resultType = llvm::cast<RankedTensorType>(
2396 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2397 }
2398 result.addAttributes(attrs);
2399 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2400 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2401 b.getDenseI64ArrayAttr(staticSizes),
2402 b.getDenseI64ArrayAttr(staticStrides));
2403}
2404
2405/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2406/// result type.
2407void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2408 ArrayRef<OpFoldResult> offsets,
2409 ArrayRef<OpFoldResult> sizes,
2410 ArrayRef<OpFoldResult> strides,
2411 ArrayRef<NamedAttribute> attrs) {
2412 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2413}
2414
2415/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2416/// a Range vector.
2417void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2418 ArrayRef<Range> ranges,
2419 ArrayRef<NamedAttribute> attrs) {
2420 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2421 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2422}
2423
2424/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2425/// the type passed is nullptr, it is inferred.
2426void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2427 RankedTensorType resultType, Value source,
2428 ValueRange offsets, ValueRange sizes,
2429 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2430 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2431 offsets, [](Value v) -> OpFoldResult { return v; });
2432 SmallVector<OpFoldResult> sizeValues =
2433 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2434 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2435 strides, [](Value v) -> OpFoldResult { return v; });
2436 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2437}
2438
2439/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2440void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2441 ValueRange offsets, ValueRange sizes,
2442 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2443 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2444}
2445
2446static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2447                                          Operation *op,
2448 RankedTensorType expectedType) {
2449  switch (result) {
2450  case SliceVerificationResult::Success:
2451    return success();
2452  case SliceVerificationResult::RankTooLarge:
2453    return op->emitError("expected rank to be smaller or equal to ")
2454           << "the other rank. ";
2455  case SliceVerificationResult::SizeMismatch:
2456    return op->emitError("expected type to be ")
2457           << expectedType << " or a rank-reduced version. (size mismatch) ";
2458  case SliceVerificationResult::ElemTypeMismatch:
2459    return op->emitError("expected element type to be ")
2460           << expectedType.getElementType();
2461 default:
2462 llvm_unreachable("unexpected extract_slice op verification result");
2463 }
2464}
2465
2466/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2467/// result type, offsets set to 0 and strides set to 1.
2468void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2469 RankedTensorType resultType, Value source,
2470 ArrayRef<OpFoldResult> sizes,
2471 ArrayRef<NamedAttribute> attrs) {
2472 Attribute zeroIdxAttr = b.getIndexAttr(0);
2473 Attribute oneIdxAttr = b.getIndexAttr(1);
2474 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2475 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2476 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2477}
2478
2479/// Verifier for ExtractSliceOp.
2480LogicalResult ExtractSliceOp::verify() {
2481 RankedTensorType sourceType = getSourceType();
2482
2483 // Verify result type against inferred type.
2484 RankedTensorType expectedType =
2485 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2486  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2487  if (result != SliceVerificationResult::Success)
2488    return produceSliceErrorMsg(result, *this, expectedType);
2489
2490 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2491 // to the source tensor.
2492 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2493 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2494 getStaticStrides(), /*generateErrorMessage=*/true);
2495 if (!boundsResult.isValid)
2496 return getOperation()->emitError(boundsResult.errorMessage);
2497
2498 return success();
2499}
2500
2501llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2502 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2503}
2504
2505FailureOr<Value>
2506ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2507 ArrayRef<int64_t> desiredShape) {
2508 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2509 assert(sourceTensorType && "not a ranked tensor type");
2510 auto sourceShape = sourceTensorType.getShape();
2511 if (sourceShape.equals(desiredShape))
2512 return value;
2513 auto maybeRankReductionMask =
2514 mlir::computeRankReductionMask(sourceShape, desiredShape);
2515 if (!maybeRankReductionMask)
2516 return failure();
2517  return createCanonicalRankReducingExtractSliceOp(
2518      b, loc, value,
2519 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2520}
2521
2522LogicalResult ExtractSliceOp::reifyResultShapes(
2523 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2524 reifiedReturnShapes.resize(1);
2525 reifiedReturnShapes[0].reserve(getType().getRank());
2526 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2527 llvm::SmallBitVector droppedDims = getDroppedDims();
2528 for (const auto &size : enumerate(mixedSizes)) {
2529 if (droppedDims.test(size.index()))
2530 continue;
2531 reifiedReturnShapes[0].push_back(size.value());
2532 }
2533 return success();
2534}
2535
2536namespace {
2537/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2538/// This essentially pushes the tensor.cast past its consuming slice when
2539/// `canFoldIntoConsumerOp` is true.
2540///
2541/// Example:
2542/// ```
2543/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2544/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2545/// tensor<3x4xf32>
2546/// ```
2547/// is rewritten into:
2548/// ```
2549///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2550///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2551/// ```
2552class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2553public:
2554 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2555
2556 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2557 PatternRewriter &rewriter) const override {
2558 // Any constant operand, just return to let the constant folder kick in.
2559 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2560 return matchPattern(operand, matchConstantIndex());
2561 }))
2562 return failure();
2563
2564 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2565 if (!castOp)
2566 return failure();
2567
2568 if (!canFoldIntoConsumerOp(castOp))
2569 return failure();
2570
2571 // Pattern does not apply if the produced op would not verify.
2572 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2573 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2574 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2575 sliceOp.getStaticStrides());
2576 if (!sliceResult.isValid)
2577 return failure();
2578
2579 // Create folded extract.
2580 Location loc = sliceOp.getLoc();
2581 Value newResult = ExtractSliceOp::create(
2582 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2583 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2584 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2585 sliceOp.getStaticStrides());
2586 rewriter.replaceOp(sliceOp, newResult);
2587 return success();
2588 }
2589};
2590
2591/// Slice elements from `values` into `outValues`. `counts` holds, for each
2592/// dimension, the linearized stride in `values`, i.e. the number of elements
2592/// to advance when stepping that dimension by one.
2593/// The output values can be used to construct a DenseElementsAttr.
2594template <typename IterTy, typename ElemTy>
2595static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2596 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2597 ArrayRef<int64_t> strides,
2598 llvm::SmallVectorImpl<ElemTy> *outValues) {
2599 assert(offsets.size() == sizes.size());
2600 assert(offsets.size() == strides.size());
2601 if (offsets.empty())
2602 return;
2603
2604 int64_t offset = offsets.front();
2605 int64_t size = sizes.front();
2606 int64_t stride = strides.front();
2607 if (offsets.size() == 1) {
2608 for (int64_t i = 0; i < size; ++i, offset += stride)
2609 outValues->push_back(*(values + offset));
2610
2611 return;
2612 }
2613
2614 for (int64_t i = 0; i < size; ++i, offset += stride) {
2615 auto begin = values + offset * counts.front();
2616 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2617 offsets.drop_front(), sizes.drop_front(),
2618 strides.drop_front(), outValues);
2619 }
2620}
2621
2622/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2623/// folded operation might introduce more constant data; users can control
2624/// whether the folding applies via the control function.
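///
/// Illustrative example (subject to the control function accepting the op):
///
/// ```mlir
/// %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
/// %0 = tensor.extract_slice %cst[0, 1] [2, 2] [1, 1]
///        : tensor<2x3xi32> to tensor<2x2xi32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %0 = arith.constant dense<[[1, 2], [4, 5]]> : tensor<2x2xi32>
/// ```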
2625class ConstantOpExtractSliceFolder final
2626 : public OpRewritePattern<ExtractSliceOp> {
2627public:
2628 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2629
2630 ConstantOpExtractSliceFolder(MLIRContext *context,
2631                               ControlConstantExtractSliceFusionFn controlFn)
2632      : OpRewritePattern<ExtractSliceOp>(context),
2633 controlFn(std::move(controlFn)) {}
2634
2635 LogicalResult matchAndRewrite(ExtractSliceOp op,
2636 PatternRewriter &rewriter) const override {
2637 DenseElementsAttr attr;
2638 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2639 return failure();
2640
2641 // A constant splat is handled by fold().
2642 if (attr.isSplat())
2643 return failure();
2644
2645 // Dynamic result shape is not supported.
2646 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2647 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2648 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2649 return failure();
2650
2651 // Customized control over the folding.
2652 if (!controlFn(op))
2653 return failure();
2654
2655 int64_t count = sourceType.getNumElements();
2656 if (count == 0)
2657 return failure();
2658
2659 // Check if there are any dynamic parts, which are not supported.
2660 auto offsets = op.getStaticOffsets();
2661 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2662 return failure();
2663 auto sizes = op.getStaticSizes();
2664 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2665 return failure();
2666 auto strides = op.getStaticStrides();
2667 if (llvm::is_contained(strides, ShapedType::kDynamic))
2668 return failure();
2669
2670 // Compute the stride for each dimension.
2671 SmallVector<int64_t> counts;
2672 ArrayRef<int64_t> shape = sourceType.getShape();
2673 counts.reserve(shape.size());
2674 for (int64_t v : shape) {
2675 count = count / v;
2676 counts.push_back(count);
2677 }
2678
2679 // New attribute constructed by the sliced values.
2680 DenseElementsAttr newAttr;
2681
2682 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2683 SmallVector<APInt> outValues;
2684 outValues.reserve(sourceType.getNumElements());
2685 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2686 elems.begin(), counts, offsets, sizes, strides, &outValues);
2687 newAttr = DenseElementsAttr::get(resultType, outValues);
2688 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2689 SmallVector<APFloat> outValues;
2690 outValues.reserve(sourceType.getNumElements());
2691 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2692 elems.begin(), counts, offsets, sizes, strides, &outValues);
2693 newAttr = DenseElementsAttr::get(resultType, outValues);
2694 }
2695
2696 if (newAttr) {
2697 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2698 return success();
2699 }
2700
2701 return failure();
2702 }
2703
2704private:
2705 /// This additionally controls whether the fold happens or not. Users can
2706 /// impose their heuristics in the function.
2707  ControlConstantExtractSliceFusionFn controlFn;
2708};
2709
2710} // namespace
2711
2712void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2713    RewritePatternSet &patterns,
2714    const ControlConstantExtractSliceFusionFn &controlFn) {
2715 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2716}
2717
2718/// Return the canonical type of the result of an extract_slice op.
2719struct SliceReturnTypeCanonicalizer {
2720  RankedTensorType operator()(ExtractSliceOp op,
2721 ArrayRef<OpFoldResult> mixedOffsets,
2722 ArrayRef<OpFoldResult> mixedSizes,
2723 ArrayRef<OpFoldResult> mixedStrides) {
2724 // Infer a tensor type without taking into account any rank reductions.
2725 RankedTensorType nonReducedType =
2726 ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);
2727
2728 // Directly return the non-rank reduced type if there are no dropped
2729 // dims.
2730 llvm::SmallBitVector droppedDims = op.getDroppedDims();
2731 if (droppedDims.none())
2732 return nonReducedType;
2733
2734 // Build the reduced shape, preserving the original rank reduction pattern.
2735 SmallVector<int64_t> targetShape;
2736 for (auto i : llvm::seq<int64_t>(mixedSizes.size()))
2737 if (!droppedDims.test(i))
2738 targetShape.push_back(nonReducedType.getDimSize(i));
2739
2740 return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
2741 nonReducedType.getEncoding());
2742 }
2743};
2744
2745/// A canonicalizer wrapper to replace ExtractSliceOps.
2746struct SliceCanonicalizer {
2747  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2748 ExtractSliceOp newOp) {
2749 Value replacement = newOp.getResult();
2750 if (replacement.getType() != op.getType())
2751 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2752 replacement);
2753 rewriter.replaceOp(op, replacement);
2754 }
2755};
2756
2757void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2758 MLIRContext *context) {
2759 results.add<
2760 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2761 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2762 ExtractSliceOpCastFolder>(context);
2763}
2764
2765// Returns success if `op` describes an identity slice of `shapedType`: all offsets are 0, the sizes match the type's shape, and all strides are 1.
2766static LogicalResult
2767foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2768 ShapedType shapedType) {
2769 OpBuilder b(op.getContext());
2770 for (OpFoldResult ofr : op.getMixedOffsets())
2771 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2772 return failure();
2773 // Rank-reducing noops only need to inspect the leading dimensions:
2774 // llvm::zip is appropriate.
2775 auto shape = shapedType.getShape();
2776 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2777 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2778 return failure();
2779 for (OpFoldResult ofr : op.getMixedStrides())
2780 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2781 return failure();
2782 return success();
2783}
2784
2785/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2786/// slice, we can return the InsertSliceOp's source directly.
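///
/// Illustrative example:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 1] [2, 4] [1, 1]
///        : tensor<2x4xf32> into tensor<8x8xf32>
/// %1 = tensor.extract_slice %0[0, 1] [2, 4] [1, 1]
///        : tensor<8x8xf32> to tensor<2x4xf32>
/// ```
///
/// Here %1 folds to %src.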
2787// TODO: This only checks the immediate producer; extend to go up the
2788// insert/extract chain if the slices are disjoint.
2789static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2790 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2791
2792 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2793 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2794 insertOp.isSameAs(extractOp, isSame))
2795 return insertOp.getSource();
2796
2797 return {};
2798}
2799
2800OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2801 if (OpFoldResult reshapedSource = reshapeConstantSource(
2802 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2803 getResult().getType()))
2804 return reshapedSource;
2805 if (getSourceType() == getType() &&
2806      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2807    return this->getSource();
2808 if (Value slice = foldExtractAfterInsertSlice(*this))
2809 return slice;
2810
2811 return OpFoldResult();
2812}
2813
2814Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2815    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2816 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2817 unsigned rank = rankedTensorType.getRank();
2818 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2819  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2820  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2821 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2822 offsets, sizes, strides);
2823}
2824
2825//===----------------------------------------------------------------------===//
2826// InsertSliceOp
2827//===----------------------------------------------------------------------===//
2828
2829void InsertSliceOp::getAsmResultNames(
2830 function_ref<void(Value, StringRef)> setNameFn) {
2831 setNameFn(getResult(), "inserted_slice");
2832}
2833
2834// Build an InsertSliceOp with mixed static and dynamic entries.
2835void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2836 Value dest, ArrayRef<OpFoldResult> offsets,
2837                          ArrayRef<OpFoldResult> sizes,
2838                          ArrayRef<OpFoldResult> strides,
2839                          ArrayRef<NamedAttribute> attrs) {
2840  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2841 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2842 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2843 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2844 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2845 result.addAttributes(attrs);
2846 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2847 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2848 b.getDenseI64ArrayAttr(staticSizes),
2849 b.getDenseI64ArrayAttr(staticStrides));
2850}
2851
2852/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2853/// Range vector.
2854void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2855 Value dest, ArrayRef<Range> ranges,
2856 ArrayRef<NamedAttribute> attrs) {
2857 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2858 build(b, result, source, dest, offsets, sizes, strides, attrs);
2859}
2860
2861// Build an InsertSliceOp with dynamic entries.
2862void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2863 Value dest, ValueRange offsets, ValueRange sizes,
2864 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2865 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2866 offsets, [](Value v) -> OpFoldResult { return v; });
2867 SmallVector<OpFoldResult> sizeValues =
2868 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2869 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2870 strides, [](Value v) -> OpFoldResult { return v; });
2871 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2872}
2873
2874/// Rank-reducing type verification for both InsertSliceOp and
2875/// ParallelInsertSliceOp.
2876static SliceVerificationResult verifyInsertSliceOp(
2877    RankedTensorType srcType, RankedTensorType dstType,
2878 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2879 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2880 // insert_slice is the inverse of extract_slice, use the same type
2881 // inference.
2882 RankedTensorType expected =
2883 ExtractSliceOp::inferResultType(dstType, staticSizes);
2884 if (expectedType)
2885 *expectedType = expected;
2886 return isRankReducedType(expected, srcType);
2887}
2888
2889/// Verifier for InsertSliceOp.
2890LogicalResult InsertSliceOp::verify() {
2891 // Verify result type against inferred type.
2892 RankedTensorType expectedType;
2893  SliceVerificationResult result =
2894      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2895                          getStaticSizes(), getStaticStrides(), &expectedType);
2896  if (result != SliceVerificationResult::Success)
2897    return produceSliceErrorMsg(result, *this, expectedType);
2898
2899 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2900 // to the destination tensor.
2901 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2902 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2903 getStaticStrides(), /*generateErrorMessage=*/true);
2904 if (!boundsResult.isValid)
2905 return getOperation()->emitError(boundsResult.errorMessage);
2906
2907 return success();
2908}
2909
2910/// If we have two consecutive InsertSliceOp writing to the same slice, we
2911/// can mutate the second InsertSliceOp's destination to the first one's.
2912///
2913/// Example:
2914///
2915/// ```mlir
2916/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2917/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2918/// ```
2919///
2920/// folds into:
2921///
2922/// ```mlir
2923/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2924/// ```
2925///
2926/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2927static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2928 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2929
2930 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2931 if (!prevInsertOp ||
2932 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2933 !prevInsertOp.isSameAs(insertOp, isSame))
2934 return failure();
2935
2936 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2937 return success();
2938}
2939
2940/// Folds round-trip extract/insert slice op pairs.
2941/// Example:
2942/// ```mlir
2943/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2944/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2945/// ```
2946/// can be folded into %val.
2947static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2948 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2949
2950 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2951 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2952 !extractOp.isSameAs(insertOp, isSame))
2953 return nullptr;
2954
2955 return extractOp.getSource();
2956}
2957
2958OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2959 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2960 getSourceType() == getType() &&
2961      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2962    return this->getSource();
2963 if (succeeded(foldInsertAfterInsertSlice(*this)))
2964 return getResult();
2965 if (auto result = foldInsertAfterExtractSlice(*this))
2966 return result;
2967 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2968 return getDest();
2969 return OpFoldResult();
2970}
2971
2972LogicalResult InsertSliceOp::reifyResultShapes(
2973 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2974 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2975 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2976 return success();
2977}
2978
2979namespace {
2980/// Pattern to rewrite an insert_slice op with constant arguments.
2981///
2982/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
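///
/// Illustrative example (%c0 is a constant index):
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[%c0, 4] [2, 2] [1, 1]
///        : tensor<2x2xf32> into tensor<8x8xf32>
/// ```
///
/// canonicalizes to:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 4] [2, 2] [1, 1]
///        : tensor<2x2xf32> into tensor<8x8xf32>
/// ```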
2983template <typename InsertOpTy>
2984class InsertSliceOpConstantArgumentFolder final
2985 : public OpRewritePattern<InsertOpTy> {
2986public:
2987 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2988
2989 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2990 PatternRewriter &rewriter) const override {
2991 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2992 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2993 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2994
2995    // No constant operands were folded, just return.
2996 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2997 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2998 failed(foldDynamicStrideList(mixedStrides)))
2999 return failure();
3000
3001 // Pattern does not apply if the produced op would not verify.
3002 SliceBoundsVerificationResult sliceResult =
3003 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
3004 mixedOffsets, mixedSizes, mixedStrides);
3005 if (!sliceResult.isValid)
3006 return failure();
3007
3008 // Create the new op in canonical form.
3009 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
3010 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
3011 mixedSizes);
3012 Value toInsert = insertSliceOp.getSource();
3013 if (sourceType != insertSliceOp.getSourceType()) {
3014 OpBuilder::InsertionGuard g(rewriter);
3015 // The only difference between InsertSliceOp and ParallelInsertSliceOp
3016 // is that the insertion point is just before the InParallelOp in
3017 // the parallel case.
3018 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
3019 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3020 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3021 sourceType, toInsert);
3022 }
3023 rewriter.replaceOpWithNewOp<InsertOpTy>(
3024 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
3025 mixedSizes, mixedStrides);
3026 return success();
3027 }
3028};
3029
3030/// Fold tensor_casts with insert_slice operations. If the source or
3031/// destination tensor is a tensor_cast that removes static type information,
3032/// the cast is folded into the insert_slice operation. E.g.:
3033///
3034/// ```mlir
3035/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3036/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3037/// ```
3038///
3039/// folds into:
3040///
3041/// ```mlir
3042/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3043/// ```
3044///
3045/// Note: When folding a cast on the destination tensor, the result of the
3046/// insert_slice operation is cast to ensure that the type of the result does
3047/// not change.
3048///
3049/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3050template <typename InsertOpTy>
3051struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3052 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3053
3054 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3055 PatternRewriter &rewriter) const override {
3056 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3057 return matchPattern(operand, matchConstantIndex());
3058 }))
3059 return failure();
3060
3061 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3062 auto castOp = v.getDefiningOp<tensor::CastOp>();
3063 if (!castOp || !canFoldIntoConsumerOp(castOp))
3064 return std::nullopt;
3065 return castOp.getSource();
3066 };
3067 std::optional<Value> sourceCastSource =
3068 getSourceOfCastOp(insertSliceOp.getSource());
3069 std::optional<Value> destCastSource =
3070 getSourceOfCastOp(insertSliceOp.getDest());
3071 if (!sourceCastSource && !destCastSource)
3072 return failure();
3073
3074 auto src =
3075 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3076 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3077 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3078 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3079 if (!srcType || !dstType)
3080 return failure();
3081
3082 // The tensor.cast source could have additional static information not seen
3083 // in the insert slice op static sizes, so we ignore dynamic dims when
3084 // computing the rank reduction mask.
3085 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3086 auto rankReductionMask = computeRankReductionMask(
3087 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3088 if (!rankReductionMask.has_value())
3089 return failure();
3090 // Replace dimensions in the insert slice op with corresponding static dims
3091 // from the cast source type. If the insert slice sizes have static dims
3092 // that are not static in the tensor.cast source (i.e., when the cast op
3093 // casts a dynamic dim to static), the dim should not be replaced, and the
3094 // pattern will fail later in `verifyInsertSliceOp`.
3095 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3096 int64_t rankReducedIdx = 0;
3097 for (auto [idx, size] : enumerate(staticSizes)) {
3098 if (!rankReductionMask.value().contains(idx) &&
3099 !srcType.isDynamicDim(rankReducedIdx)) {
3100 mixedSizes[idx] = getAsIndexOpFoldResult(
3101 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3102 size = srcType.getDimSize(rankReducedIdx++);
3103 }
3104 }
3105
3106 // Pattern does not apply if the produced op would not verify.
3107 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3108 staticSizes, insertSliceOp.getStaticStrides()) !=
3109 SliceVerificationResult::Success)
3110 return failure();
3111 SliceBoundsVerificationResult sliceResult =
3112 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3113 mixedSizes, insertSliceOp.getMixedStrides());
3114 if (!sliceResult.isValid)
3115 return failure();
3116
3117 Operation *replacement =
3118 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3119 insertSliceOp.getMixedOffsets(), mixedSizes,
3120 insertSliceOp.getMixedStrides());
3121
3122 // In the parallel case there is no result and so nothing to cast.
3123 bool isParallelInsert =
3124 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3125 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3126 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3127 insertSliceOp.getDestType(),
3128 replacement->getResult(0));
3129 }
3130 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3131 return success();
3132 }
3133};
3134
3135/// If additional static type information can be deduced from an insert_slice's
3136/// size operands, insert an explicit cast of the op's source operand. This
3137/// enables other canonicalization patterns that are matching for tensor_cast
3138/// ops such as `ForOpTensorCastFolder` in SCF.
3139///
3140/// Example:
3141///
3142/// ```mlir
3143/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3144/// : tensor<?x?xf32> into ...
3145/// ```
3146///
3147/// folds into:
3148///
3149/// ```mlir
3150/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3151/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3152/// : tensor<64x64xf32> into ...
3153/// ```
3154///
3155/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3156template <typename InsertOpTy>
3157struct InsertSliceOpSourceCastInserter final
3158 : public OpRewritePattern<InsertOpTy> {
3159 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3160
3161 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3162 PatternRewriter &rewriter) const override {
3163 RankedTensorType srcType = insertSliceOp.getSourceType();
3164 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3165 return failure();
3166 SmallVector<int64_t> newSrcShape(srcType.getShape());
3167 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3168 if (std::optional<int64_t> constInt =
3169 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3170 // Bail on invalid IR.
3171 if (*constInt < 0)
3172 return failure();
3173 newSrcShape[i] = *constInt;
3174 }
3175 }
3176 if (!hasValidSizesOffsets(newSrcShape))
3177 return failure();
3178
3179 RankedTensorType newSrcType = RankedTensorType::get(
3180 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3181 if (srcType == newSrcType ||
3182 !preservesStaticInformation(srcType, newSrcType) ||
3183 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3184 return failure();
3185
3186 // newSrcType is:
3187 // 1) Different from srcType.
3188 // 2) "More static" than srcType.
3189 // 3) Cast-compatible with srcType.
3190 // Insert the cast.
3191 OpBuilder::InsertionGuard g(rewriter);
3192 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3193 // that the insertion point is just before the InParallelOp in the
3194 // parallel case.
3195 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3196 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3197 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3198 newSrcType, insertSliceOp.getSource());
3199 rewriter.replaceOpWithNewOp<InsertOpTy>(
3200 insertSliceOp, cast, insertSliceOp.getDest(),
3201 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3202 insertSliceOp.getMixedStrides());
3203 return success();
3204 }
3205};
3206} // namespace
3207
3208llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3209 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3210}
3211
3212void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3213 MLIRContext *context) {
3214 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3215 InsertSliceOpCastFolder<InsertSliceOp>,
3216 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3217}
3218
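/// For illustration (a hedged sketch; the names and types are made up),
/// inserting a tensor<16xf32> value into a tensor<1x16xf32> destination with
/// this helper yields a rank-reducing insert at offset 0 with unit strides:
///
/// ```mlir
/// %r = tensor.insert_slice %tensor into %dest[0, 0] [1, 16] [1, 1]
///     : tensor<16xf32> into tensor<1x16xf32>
/// ```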
3219 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3220                                                              Location loc,
3221 Value tensor,
3222 Value dest) {
3223 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3224 unsigned rank = rankedTensorType.getRank();
3225 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3226 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3227 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3228 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3229 sizes, strides);
3230}
3231
3232//===----------------------------------------------------------------------===//
3233// PadOp
3234//===----------------------------------------------------------------------===//
3235
3236void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3237 setNameFn(getResult(), "padded");
3238}
3239
3240LogicalResult PadOp::verify() {
3241 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3242 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3243 auto expectedType =
3244 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3245 if (!expectedType) {
3246 return emitError("failed to infer expectedType from sourceType ")
3247 << sourceType << ", specified resultType is " << resultType;
3248 }
3249 if (resultType.getRank() != expectedType.getRank()) {
3250 return emitError("specified type ")
3251 << resultType << " does not match the inferred type "
3252 << expectedType;
3253 }
3254 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3255 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3256 continue;
3257 if (expectedType.isDynamicDim(i))
3258 continue;
3259 return emitError("specified type ")
3260 << resultType << " does not match the inferred type "
3261 << expectedType;
3262 }
3263
3264 return success();
3265}
3266
3267LogicalResult PadOp::verifyRegions() {
3268 auto &region = getRegion();
3269 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3270 Block &block = region.front();
3271 if (block.getNumArguments() != rank)
3272 return emitError("expected the block to have ") << rank << " arguments";
3273
3274 // Note: the number and type of yield values are checked in the YieldOp.
3275 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3276 if (!en.value().isIndex())
3277 return emitOpError("expected block argument ")
3278 << (en.index() + 1) << " to be an index";
3279 }
3280
3281 // Ensure that the region yields an element of the right type.
3282 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3283 if (yieldOp.getValue().getType() !=
3284 llvm::cast<ShapedType>(getType()).getElementType())
3285 return emitOpError("expected yield type to match shape element type");
3286
3287 return success();
3288}
3289
3290RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3291 ArrayRef<int64_t> staticLow,
3292 ArrayRef<int64_t> staticHigh,
3293 ArrayRef<int64_t> resultShape) {
3294 unsigned rank = sourceType.getRank();
3295 if (staticLow.size() != rank)
3296 return RankedTensorType();
3297 if (staticHigh.size() != rank)
3298 return RankedTensorType();
3299 if (!resultShape.empty() && resultShape.size() != rank)
3300 return RankedTensorType();
3301
3302 SmallVector<int64_t, 4> inferredShape;
3303 for (auto i : llvm::seq<unsigned>(0, rank)) {
3304 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3305 staticHigh[i] == ShapedType::kDynamic) {
3306 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3307 : resultShape[i]);
3308 } else {
3309 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3310 assert((resultShape.empty() || size == resultShape[i] ||
3311 resultShape[i] == ShapedType::kDynamic) &&
3312 "mismatch between inferred shape and result shape");
3313 inferredShape.push_back(size);
3314 }
3315 }
3316
3317 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3318}
3319
3320void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3321 Value source, ArrayRef<int64_t> staticLow,
3322 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3323 bool nofold, ArrayRef<NamedAttribute> attrs) {
3324 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3325 if (!resultType)
3326 resultType = inferResultType(sourceType, staticLow, staticHigh);
3327 result.addAttributes(attrs);
3328 build(b, result, resultType, source, low, high,
3329 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3330 nofold ? b.getUnitAttr() : UnitAttr());
3331}
3332
3333void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3334 Value source, ValueRange low, ValueRange high, bool nofold,
3335 ArrayRef<NamedAttribute> attrs) {
3336 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3337 unsigned rank = sourceType.getRank();
3338 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3339 build(b, result, resultType, source, staticVector, staticVector, low, high,
3340 nofold, attrs);
3341}
3342
3343void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3344 Value source, ArrayRef<OpFoldResult> low,
3345 ArrayRef<OpFoldResult> high, bool nofold,
3346 ArrayRef<NamedAttribute> attrs) {
3347 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3348 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3349 SmallVector<int64_t, 4> staticLow, staticHigh;
3350 // staticLow and staticHigh have full information of the padding config.
3351 // This will grow staticLow and staticHigh with 1 value. If the config is
3352 // dynamic (i.e., not a constant), dynamicLow and dynamicHigh will grow with 1
3353 // value as well.
3354 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3355 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3356 if (!resultType) {
3357 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3358 }
3359 assert(llvm::isa<RankedTensorType>(resultType));
3360 result.addAttributes(attrs);
3361 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3362 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3363 nofold ? b.getUnitAttr() : UnitAttr());
3364}
3365
3366void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3367 Value source, ArrayRef<OpFoldResult> low,
3368 ArrayRef<OpFoldResult> high, Value constantPadValue,
3369 bool nofold, ArrayRef<NamedAttribute> attrs) {
3370 build(b, result, resultType, source, low, high, nofold, attrs);
3371
3372 // Add a region and a block to yield the pad value.
3373 Region *region = result.regions[0].get();
3374 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3375 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3376 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3377
3378 // `builder.createBlock` changes the insertion point within the block. Create
3379 // a guard to reset the insertion point of the builder after it is destroyed.
3380 OpBuilder::InsertionGuard guard(b);
3381 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3382 tensor::YieldOp::create(b, result.location, constantPadValue);
3383}
3384
3385llvm::SmallBitVector PadOp::getPaddedDims() {
3386 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3387 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3388 for (const auto &en : enumerate(paddingWidths))
3389 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3390 paddedDims.set(en.index());
3391 };
3392 extractPaddedDims(getMixedLowPad());
3393 extractPaddedDims(getMixedHighPad());
3394 return paddedDims;
3395}
3396
3397namespace {
3398// Folds tensor.pad when the padding is statically all-zero and the nofold
3399// attribute doesn't request otherwise.
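//
// For example (illustrative; the cast may itself fold away when the types
// already match):
//
// ```mlir
// %p = tensor.pad %t low[0, 0] high[0, 0] { ...
// } : tensor<4x8xf32> to tensor<4x8xf32>
// ```
//
// becomes:
//
// ```mlir
// %p = tensor.cast %t : tensor<4x8xf32> to tensor<4x8xf32>
// ```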
3400struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3401 using OpRewritePattern<PadOp>::OpRewritePattern;
3402
3403 LogicalResult matchAndRewrite(PadOp padTensorOp,
3404 PatternRewriter &rewriter) const override {
3405 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3406 return failure();
3407 if (padTensorOp.getNofold())
3408 return failure();
3409 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3410 padTensorOp, padTensorOp.getResult().getType(),
3411 padTensorOp.getSource());
3412 return success();
3413 }
3414};
3415
3416// Fold CastOp into PadOp when adding static information.
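//
// For example (illustrative), a cast that erased static sizes is pulled into
// the pad, and the original result type is restored with a cast:
//
// ```mlir
// %c = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>
// %p = tensor.pad %c low[0, 0] high[1, 2] { ...
// } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// becomes roughly:
//
// ```mlir
// %0 = tensor.pad %t low[0, 0] high[1, 2] { ...
// } : tensor<4x8xf32> to tensor<5x10xf32>
// %p = tensor.cast %0 : tensor<5x10xf32> to tensor<?x?xf32>
// ```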
3417struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3418 using OpRewritePattern<PadOp>::OpRewritePattern;
3419
3420 LogicalResult matchAndRewrite(PadOp padTensorOp,
3421 PatternRewriter &rewriter) const override {
3422 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3423 if (!tensor::canFoldIntoConsumerOp(castOp))
3424 return failure();
3425
3426 auto newResultType = PadOp::inferResultType(
3427 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3428 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3429 padTensorOp.getResultType().getShape());
3430
3431 if (newResultType == padTensorOp.getResultType()) {
3432 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3433 padTensorOp.getSourceMutable().assign(castOp.getSource());
3434 });
3435 } else {
3436 auto newOp = PadOp::create(
3437 rewriter, padTensorOp->getLoc(), newResultType,
3438 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3439 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3440 padTensorOp.getHigh(), padTensorOp.getNofold(),
3441 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3442 IRMapping mapper;
3443 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3444
3445 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3446 padTensorOp, padTensorOp.getResultType(), newOp);
3447 }
3448 return success();
3449 }
3450};
3451
3452// Fold a CastOp that uses the result of a PadOp back into the PadOp if the
3453// cast adds static information.
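//
// For example (illustrative), a cast of the pad result to a more static type
// is absorbed by recreating the pad with that type:
//
// ```mlir
// %p = tensor.pad %t low[0, 0] high[%h, 2] { ...
// } : tensor<4x8xf32> to tensor<?x10xf32>
// %c = tensor.cast %p : tensor<?x10xf32> to tensor<6x10xf32>
// ```
//
// becomes:
//
// ```mlir
// %c = tensor.pad %t low[0, 0] high[%h, 2] { ...
// } : tensor<4x8xf32> to tensor<6x10xf32>
// ```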
3454struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3455 using OpRewritePattern<PadOp>::OpRewritePattern;
3456
3457 LogicalResult matchAndRewrite(PadOp padTensorOp,
3458 PatternRewriter &rewriter) const override {
3459 if (!padTensorOp.getResult().hasOneUse())
3460 return failure();
3461 auto tensorCastOp =
3462 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3463 if (!tensorCastOp)
3464 return failure();
3465 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3466 tensorCastOp.getDest().getType()))
3467 return failure();
3468
3469 auto replacementOp = PadOp::create(
3470 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3471 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3472 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3473 padTensorOp.getHigh(), padTensorOp.getNofold(),
3474 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3475 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3476
3477 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3478 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3479 return success();
3480 }
3481};
3482
3483/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3484/// different dimensions. The pattern applies if the following preconditions
3485/// hold:
3486/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3487/// 2) the tensor::ExtractSliceOps have only unit-strides,
3488/// 3) the tensor::PadOps perform only high-padding,
3489/// 4) the tensor::PadOps have the same constant padding value,
3490/// 5) the tensor::PadOps do not have common padding dimensions,
3491/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3492/// zero-offset for every dimension.
3493/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3494/// the padded source dimensions.
3496///
3497/// Example:
3498///
3499/// ```mlir
3500/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3501/// : tensor<64x64xf32> to tensor<?x64xf32>
3502/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3503/// } : tensor<?x64xf32> to tensor<8x64xf32>
3504/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3505/// : tensor<8x64xf32> to tensor<8x?xf32>
3506/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3507/// } : tensor<8x?xf32> to tensor<8x4xf32>
3508/// ```
3509///
3510/// folds into:
3511///
3512/// ```mlir
3513/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3514/// : tensor<64x64xf32> to tensor<?x?xf32>
3515/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3516/// } : tensor<?x?xf32> to tensor<8x4xf32>
3517/// ```
3518struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3519 using OpRewritePattern<PadOp>::OpRewritePattern;
3520
3521 LogicalResult matchAndRewrite(PadOp padOp,
3522 PatternRewriter &rewriter) const override {
3523 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3524 if (!innerSliceOp)
3525 return failure();
3526 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3527 if (!outerPadOp || outerPadOp.getNofold())
3528 return failure();
3529 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3530 if (!outerSliceOp)
3531 return failure();
3532
3533 // 1) Fail if the chain is rank-reducing.
3534 int64_t rank = padOp.getSourceType().getRank();
3535 if (outerSliceOp.getSourceType().getRank() != rank) {
3536 return rewriter.notifyMatchFailure(padOp,
3537 "cannot fold rank-reducing chain");
3538 }
3539
3540 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3541 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3542 return rewriter.notifyMatchFailure(
3543 padOp, "cannot fold non-unit stride ExtractSliceOps");
3544 }
3545
3546 // 3) Fail if the tensor::PadOps have non-zero low padding.
3547 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3548 return rewriter.notifyMatchFailure(padOp,
3549 "cannot fold PadOps with low padding");
3550 }
3551
3552 // 4) Fail if the tensor::PadOps padding values do not match.
3553 Attribute innerAttr, outerAttr;
3554 Value innerValue = padOp.getConstantPaddingValue();
3555 Value outerValue = outerPadOp.getConstantPaddingValue();
3556 if (!innerValue || !outerValue ||
3557 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3558 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3559 innerAttr != outerAttr) {
3560 return rewriter.notifyMatchFailure(
3561 padOp, "cannot fold PadOps with different padding values");
3562 }
3563
3564 // 5) Fail if a dimension is padded by both tensor::PadOps.
3565 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3566 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3567 if (innerDims.anyCommon(outerDims)) {
3568 return rewriter.notifyMatchFailure(
3569 padOp, "cannot fold PadOps with common padding dimensions");
3570 }
3571
3572 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3573 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3574 // for every dimension, and use the offset of the other pair. Fail if no
3575 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3576 // exists.
3577 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3578 for (auto en : enumerate(newOffsets)) {
3579 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3580 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3581 if (!innerDims.test(en.index()) &&
3582 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3583 en.value() = outerOffset;
3584 continue;
3585 }
3586 if (!outerDims.test(en.index()) &&
3587 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3588 en.value() = innerOffset;
3589 continue;
3590 }
3591 return rewriter.notifyMatchFailure(
3592 padOp, "cannot find zero-offset and zero-padding pair");
3593 }
3594
3595 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3596 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3597 // outer tensor::PadOp and fail if the size of the inner
3598 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3599 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3600 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3601 for (auto en : enumerate(newSizes)) {
3602 if (!outerDims.test(en.index()))
3603 continue;
3604 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3605 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3606 assert(ShapedType::isStatic(sourceSize) &&
3607 "expected padded dimension to have a static size");
3608 if (getConstantIntValue(sliceSize) != sourceSize) {
3609 return rewriter.notifyMatchFailure(
3610 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3611 "match the size of the outer padding");
3612 }
3613 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3614 }
3615
3616 // Combine the high paddings of the two tensor::PadOps.
3617 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3618 for (auto en : enumerate(newHighPad)) {
3619 if (innerDims.test(en.index()))
3620 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3621 if (outerDims.test(en.index()))
3622 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3623 }
3624
3625 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3626 // the two paddings in one step.
3627 auto newSliceOp = ExtractSliceOp::create(
3628 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3629 newSizes, innerSliceOp.getMixedStrides());
3630 auto newPadOp = PadOp::create(
3631 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3632 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3633 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3634 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3635 newPadOp.getRegion().begin());
3636 rewriter.replaceOp(padOp, newPadOp.getResult());
3637 return success();
3638 }
3639};
3640
3641struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3642 using OpRewritePattern<PadOp>::OpRewritePattern;
3643
3644 LogicalResult matchAndRewrite(PadOp padTensorOp,
3645 PatternRewriter &rewriter) const override {
3646 Value input = padTensorOp.getSource();
3647 if (!llvm::isa<RankedTensorType>(input.getType()))
3648 return failure();
3649 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3650 auto inputRank = inputDims.size();
3651
3652 auto oldResultType =
3653 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3654 if (!oldResultType)
3655 return failure();
3656
3657 auto outputDims = oldResultType.getShape();
3658
3659 // Extract the static info from the high and low operands.
3660 SmallVector<int64_t> constOperandsLow;
3661 SmallVector<Value> newLows;
3662 for (auto operand : padTensorOp.getLow()) {
3663 APSInt intOp;
3664 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3665 constOperandsLow.push_back(ShapedType::kDynamic);
3666 newLows.push_back(operand);
3667 continue;
3668 }
3669 constOperandsLow.push_back(intOp.getExtValue());
3670 }
3671 SmallVector<int64_t> constOperandsHigh;
3672 SmallVector<Value> newHighs;
3673 for (auto operand : padTensorOp.getHigh()) {
3674 APSInt intOp;
3675 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3676 constOperandsHigh.push_back(ShapedType::kDynamic);
3677 newHighs.push_back(operand);
3678 continue;
3679 }
3680 constOperandsHigh.push_back(intOp.getExtValue());
3681 }
3682
3683 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3684 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3685
3686 // Verify the op is well-formed.
3687 if (inputDims.size() != outputDims.size() ||
3688 inputDims.size() != constLow.size() ||
3689 inputDims.size() != constHigh.size())
3690 return failure();
3691
3692 auto lowCount = 0;
3693 auto highCount = 0;
3694 for (size_t i = 0; i < inputRank; i++) {
3695 if (constLow[i] == ShapedType::kDynamic)
3696 constLow[i] = constOperandsLow[lowCount++];
3697 if (constHigh[i] == ShapedType::kDynamic)
3698 constHigh[i] = constOperandsHigh[highCount++];
3699 }
3700
3701 auto staticLow = ArrayRef<int64_t>(constLow);
3702 auto staticHigh = ArrayRef<int64_t>(constHigh);
3703
3704 // Calculate the output sizes with the static information.
3705 SmallVector<int64_t> newOutDims;
3706 for (size_t i = 0; i < inputRank; i++) {
3707 if (outputDims[i] == ShapedType::kDynamic) {
3708 newOutDims.push_back(
3709 (staticLow[i] == ShapedType::kDynamic ||
3710 staticHigh[i] == ShapedType::kDynamic ||
3711 inputDims[i] == ShapedType::kDynamic
3712 ? ShapedType::kDynamic
3713 : inputDims[i] + staticLow[i] + staticHigh[i]));
3714 } else {
3715 newOutDims.push_back(outputDims[i]);
3716 }
3717 }
3718
3719 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3720 llvm::all_of(newOutDims,
3721 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3722 return failure();
3723
3724 // Rewrite the op using the new static type.
3725 auto newResultType = RankedTensorType::get(
3726 newOutDims, padTensorOp.getType().getElementType());
3727 auto newOp = PadOp::create(
3728 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3729 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3730 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3731
3732 IRMapping mapper;
3733 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3734 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3735 newOp);
3736
3737 return success();
3738 }
3739};
3740
3741/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3742///
3743/// Example:
3744///
3745/// ```mlir
3746/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3747/// tensor.yield %val
3748/// } : tensor<1x2xf32> to tensor<2x5xf32>
3749/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3750/// tensor.yield %val
3751/// } : tensor<1x5xf32> to tensor<5x7xf32>
3752/// ```
3753///
3754/// folds into:
3755///
3756/// ```mlir
3757/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3758/// tensor.yield %val
3759/// } : tensor<1x2xf32> to tensor<5x7xf32>
3760/// ```
3761struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3762 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3763
3764 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3765 PatternRewriter &rewriter) const override {
3766 if (padOp.getNofold()) {
3767 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3768 }
3769
3770 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3771 if (!producerPad || producerPad.getNofold()) {
3772 return rewriter.notifyMatchFailure(
3773 padOp, "producer is not a foldable tensor.pad op");
3774 }
3775
3776 // Fail if the tensor::PadOps padding values do not match.
3777 Value consumerPadValue = padOp.getConstantPaddingValue();
3778 Value producerPadValue = producerPad.getConstantPaddingValue();
3779 if (!consumerPadValue || !producerPadValue ||
3780 consumerPadValue != producerPadValue) {
3781 return rewriter.notifyMatchFailure(
3782 padOp,
3783 "cannot fold PadOps with different or non-constant padding values");
3784 }
3785
3786 Location loc = padOp.getLoc();
3787 AffineExpr d0, d1;
3788 bindDims(rewriter.getContext(), d0, d1);
3789
3790 // Combine the low/high paddings of the two tensor::PadOps.
3791 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3792 ArrayRef<OpFoldResult> producerPaddings) {
3793 SmallVector<OpFoldResult> sumPaddings;
3794 for (auto [consumerIndex, producerIndex] :
3795 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3796 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3797 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3798 }
3799 return sumPaddings;
3800 };
3801
3802 SmallVector<OpFoldResult> newHighPad =
3803 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3804 SmallVector<OpFoldResult> newLowPad =
3805 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3806
3807 auto newPadOp = tensor::PadOp::create(
3808 rewriter, padOp.getLoc(), padOp.getResultType(),
3809 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3810 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3811 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3812 newPadOp.getRegion().begin());
3813 rewriter.replaceOp(padOp, newPadOp.getResult());
3814 return success();
3815 }
3816};
3817
3818} // namespace
3819
3820LogicalResult
3821PadOp::reifyResultShapes(OpBuilder &b,
3822 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3823 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3824 SmallVector<OpFoldResult> lp = getMixedLowPad();
3825 SmallVector<OpFoldResult> hp = getMixedHighPad();
3826 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3827 if (!getType().isDynamicDim(i)) {
3828 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3829 continue;
3830 }
3831 Location loc = getLoc();
3832 Value dim = b.createOrFold<tensor::DimOp>(
3833 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3834
3835 AffineExpr d0, d1, d2;
3836 bindDims(b.getContext(), d0, d1, d2);
3837 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3838 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3839 }
3840 return success();
3841}
3842
3843void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3844 MLIRContext *context) {
3845 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3846 FoldOrthogonalPaddings, FoldStaticPadding,
3847 FoldConsecutiveConstantPadding>(context);
3848}
3849
3850/// Return the padding value of the PadOp if it is constant. In this context,
3851/// "constant" means an actual constant or "defined outside of the block".
3852///
3853/// Values are considered constant in three cases:
3854/// - A ConstantLike value.
3855/// - A basic block argument from a different block.
3856/// - A value defined outside of the block.
3857///
3858/// If the padding value is not constant, an empty Value is returned.
3859Value PadOp::getConstantPaddingValue() {
3860 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3861 if (!yieldOp)
3862 return {};
3863 Value padValue = yieldOp.getValue();
3864 // Check if yield value is a constant.
3865 if (matchPattern(padValue, m_Constant()))
3866 return padValue;
3867 // Check if yield value is defined inside the PadOp block.
3868 if (padValue.getParentBlock() == &getRegion().front())
3869 return {};
3870 // Else: Yield value defined outside of the PadOp block.
3871 return padValue;
3872}
3873
3874OpFoldResult PadOp::fold(FoldAdaptor) {
3875 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3876 !getNofold())
3877 return getSource();
3878 return {};
3879}
3880
3881//===----------------------------------------------------------------------===//
3882// ParallelInsertSliceOp
3883//===----------------------------------------------------------------------===//
3884
3885OpResult ParallelInsertSliceOp::getTiedOpResult() {
3886 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3887 for (const auto &it :
3888 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3889 Operation &nextOp = it.value();
3890 if (&nextOp == getOperation())
3891 return parallelCombiningParent.getParentResult(it.index());
3892 }
3893 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3894}
3895
3896// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3897void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3898 Value source, Value dest,
3899 ArrayRef<OpFoldResult> offsets,
3900 ArrayRef<OpFoldResult> sizes,
3901 ArrayRef<OpFoldResult> strides,
3902 ArrayRef<NamedAttribute> attrs) {
3903 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3904 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3905 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3906 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3907 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3908 result.addAttributes(attrs);
3909 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3910 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3911 b.getDenseI64ArrayAttr(staticSizes),
3912 b.getDenseI64ArrayAttr(staticStrides));
3913}
3914
3915/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3916/// packed into a Range vector.
3917void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3918 Value source, Value dest,
3919 ArrayRef<Range> ranges,
3920 ArrayRef<NamedAttribute> attrs) {
3921 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3922 build(b, result, source, dest, offsets, sizes, strides, attrs);
3923}
3924
3925// Build a ParallelInsertSliceOp with dynamic entries.
3926void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3927 Value source, Value dest, ValueRange offsets,
3928 ValueRange sizes, ValueRange strides,
3929 ArrayRef<NamedAttribute> attrs) {
3930 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3931 offsets, [](Value v) -> OpFoldResult { return v; });
3932 SmallVector<OpFoldResult> sizeValues =
3933 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3934 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3935 strides, [](Value v) -> OpFoldResult { return v; });
3936 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3937}
3938
3939// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3940// to 0, strides set to 1 and inferred result type.
3941void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3942 Value dest, ArrayRef<OpFoldResult> sizes,
3943 ArrayRef<NamedAttribute> attrs) {
3944 Attribute zeroIdxAttr = b.getIndexAttr(0);
3945 Attribute oneIdxAttr = b.getIndexAttr(1);
3946 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3947 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3948 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3949}
3950
3951LogicalResult ParallelInsertSliceOp::verify() {
3952 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3953 return this->emitError("expected InParallelOpInterface parent, got:")
3954 << *(getOperation()->getParentOp());
3955
3956 // Verify result type against inferred type.
3957 RankedTensorType expectedType;
3958 SliceVerificationResult result =
3959 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3960 getStaticSizes(), getStaticStrides(), &expectedType);
3961 if (result != SliceVerificationResult::Success)
3962 return produceSliceErrorMsg(result, *this, expectedType);
3963
3964 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3965 // to the destination tensor.
3966 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3967 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3968 getStaticStrides(), /*generateErrorMessage=*/true);
3969 if (!boundsResult.isValid)
3970 return getOperation()->emitError(boundsResult.errorMessage);
3971
3972 return success();
3973}
3974
3975void ParallelInsertSliceOp::getCanonicalizationPatterns(
3976 RewritePatternSet &results, MLIRContext *context) {
3977 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3978 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3979 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3980}
3981
3982llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3983 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3984}
3985
3986// ParallelCombiningOpInterface implementation.
3987MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3988 return getDestMutable();
3989}
3990
3991Operation *ParallelInsertSliceOp::getIteratingParent() {
3992 // Return the parent InParallelOpInterface's parent.
3993 if (auto combiningOp =
3994 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3995 return combiningOp->getParentOp();
3996 return nullptr;
3997}
3998
3999//===----------------------------------------------------------------------===//
4000// ScatterOp
4001//===----------------------------------------------------------------------===//
4002
4003void ScatterOp::getAsmResultNames(
4004 function_ref<void(Value, StringRef)> setNameFn) {
4005 setNameFn(getResult(), "scatter");
4006}
4007
4008LogicalResult ScatterOp::verify() {
4009 int64_t destRank = getDestType().getRank();
4010 ArrayRef<int64_t> scatterDims = getScatterDims();
4011 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
4012 getIndicesType().getShape(), destRank,
4013 "scatter", "dest")))
4014 return failure();
4015
4016 if (!getUnique())
4017 return emitOpError("requires 'unique' attribute to be set");
4018 // TODO: we could also check statically that there are fewer leading index
4019 // tensor dims than the dest dims. If this is not the case, the unique
4020 // attribute cannot be true.
4021
4022 // Use the GatherOp::inferResultType on the `dest` type and verify the
4023 // expected type matches the source type.
4024 RankedTensorType expectedSourceType = GatherOp::inferResultType(
4025 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
4026 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
4027 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
4028 if (getSourceType() != expectedSourceType &&
4029 getSourceType() != expectedRankReducedSourceType) {
4030 return emitOpError("source type "
4031 "mismatch: "
4032 "expected ")
4033 << expectedSourceType << " or its rank-reduced variant "
4034 << expectedRankReducedSourceType << " (got: " << getSourceType()
4035 << ")";
4036 }
4037
4038 return success();
4039}
4040
4041//===----------------------------------------------------------------------===//
4042// SplatOp
4043//===----------------------------------------------------------------------===//
4044
4045void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4046 Type aggregateType, ValueRange dynamicSizes) {
4047 build(builder, result, aggregateType, element, dynamicSizes);
4048}
4049
4050void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4051 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4052 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4053 build(builder, result, aggregateType, element, dynamicSizes);
4054}
4055
4056void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4057 ArrayRef<OpFoldResult> sizes) {
4058 SmallVector<int64_t> staticShape;
4059 SmallVector<Value> dynamicSizes;
4060 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4061 build(builder, result, element, staticShape, dynamicSizes);
4062}
4063
4064void SplatOp::getAsmResultNames(
4065 function_ref<void(Value, StringRef)> setNameFn) {
4066 setNameFn(getResult(), "splat");
4067}
4068
4069LogicalResult SplatOp::verify() {
4070 return verifyDynamicDimensionCount(getOperation(), getType(),
4071 getDynamicSizes());
4072}
4073
4074LogicalResult
4075SplatOp::reifyResultShapes(OpBuilder &builder,
4076 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4077 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4078 unsigned ctr = 0;
4079 for (int64_t i = 0; i < getType().getRank(); ++i) {
4080 if (getType().isDynamicDim(i)) {
4081 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4082 } else {
4083 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4084 }
4085 }
4086 return success();
4087}
4088
4089OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4090 auto constOperand = adaptor.getInput();
4091 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4092 return {};
4093
4094 // Do not fold if the splat is not statically shaped
4095 if (!getType().hasStaticShape())
4096 return {};
4097
4098 // SplatElementsAttr::get treats a single value for the second arg as a
4099 // splat.
4100 return SplatElementsAttr::get(getType(), {constOperand});
4101}
4102
4103//===----------------------------------------------------------------------===//
4104// Common Canonicalizers and Folders.
4105//===----------------------------------------------------------------------===//
4106static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4107 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4108 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4109 // might need special handling of attached regions.
4110 if (isa<InsertSliceOp>(op.getOperation()) ||
4111 isa<LoopLikeOpInterface>(op.getOperation()))
4112 return false;
4113
4115}
4116
4117/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4118/// the `tensor.cast` has source that is more static than the consuming op.
4119///
4120/// Example:
4121/// ```mlir
4122/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4123/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4124/// ```
4125///
4126/// folds into:
4127///
4128/// ```mlir
4129/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4130/// ```
4131/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4132/// can add the pattern to their canonicalizers.
4133 struct FoldTensorCastProducerOp
4134     : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4135   using OpInterfaceRewritePattern<
4136       DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4137
4138 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4139 PatternRewriter &rewriter) const override {
4140
4141 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4142 // for that instead.
4143 if (!foldTensorCastPrecondition(op) ||
4144 isa<linalg::RelayoutOpInterface>(*op))
4145 return failure();
4146
4147 SmallVector<Type> newResultTypes(op->getResultTypes());
4148 SmallVector<Value> newOperands =
4149 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4150
4151 // Clone op
4152 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4153
4154 SmallVector<Value, 4> replacements;
4155 replacements.reserve(newOp->getNumResults());
4156 for (auto [oldResult, newResult] :
4157 llvm::zip(op->getResults(), newOp->getResults())) {
4158 if (newResult.getType() != oldResult.getType()) {
4159 replacements.push_back(tensor::CastOp::create(
4160 rewriter, op->getLoc(), oldResult.getType(), newResult));
4161 } else {
4162 replacements.push_back(newResult);
4163 }
4164 }
4165 rewriter.replaceOp(op, replacements);
4166
4167 return success();
4168 }
4169};
4170
4171//===----------------------------------------------------------------------===//
4172// TensorDialect
4173//===----------------------------------------------------------------------===//
4174
4175void TensorDialect::getCanonicalizationPatterns(
4176 RewritePatternSet &results) const {
4177 results.add<FoldTensorCastProducerOp>(getContext());
4178}
4179
4180//===----------------------------------------------------------------------===//
4181// TableGen'd op method definitions
4182//===----------------------------------------------------------------------===//
4183
4184#define GET_OP_CLASSES
4185#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:382
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
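A minimal sketch, assuming an OpBuilder b, a Location loc, and two OpFoldResults offset and size already in scope; the computed end bound folds to an attribute when both inputs are static.
// Compute offset + size as an OpFoldResult without materializing an
// affine.apply when the result is a compile-time constant.
AffineExpr d0, d1;
bindDims(b.getContext(), d0, d1);
OpFoldResult end = affine::makeComposedFoldedAffineApply(
    b, loc, AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1),
    {offset, size});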
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
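A hedged sketch of wiring these patterns into a greedy rewrite, assuming it runs inside a pass; the control function shown (fold only slices of small constants) is an illustrative policy, not the library default.
// Requires mlir/Transforms/GreedyPatternRewriteDriver.h for the driver call.
void foldSmallConstantSlices(Operation *root, MLIRContext *ctx) {
  RewritePatternSet patterns(ctx);
  tensor::populateFoldConstantExtractSlicePatterns(
      patterns, [](tensor::ExtractSliceOp op) {
        // Opt in only when the source constant is small, to avoid
        // materializing large new constants.
        RankedTensorType src = op.getSourceType();
        return src.hasStaticShape() && src.getNumElements() <= 64;
      });
  (void)applyPatternsGreedily(root, std::move(patterns));
}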
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
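A hedged sketch of the usual consumer-side use, assuming operand is an OpOperand & of consumer and that the consumer's result types do not depend on this operand's type (destination-style ops instead rebuild the op via getUpdatedOperandsAfterCastOpFolding and cast the results back).
if (auto castOp = operand.get().getDefiningOp<tensor::CastOp>()) {
  if (tensor::canFoldIntoConsumerOp(castOp)) {
    // The cast only erased static information, so the consumer may read the
    // more static pre-cast value directly.
    rewriter.modifyOpInPlace(consumer,
                             [&] { operand.set(castOp.getSource()); });
  }
}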
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the size of the given dimension of the tensor value.
Definition TensorOps.cpp:59
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:77
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the sizes of all dimensions of the given tensor value.
Definition TensorOps.cpp:68
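A minimal sketch in the spirit of getOrCreateDestination: build a tensor.empty with the same (possibly dynamic) shape as an existing ranked tensor value source, assuming b and loc are in scope.
SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, source);
Value dest = tensor::EmptyOp::create(
    b, loc, sizes,
    llvm::cast<RankedTensorType>(source.getType()).getElementType());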
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
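A minimal sketch, assuming b is an OpBuilder and op is an operation with tensor results that is about to be rewritten into destination-passing style.
SmallVector<Value> destinations;
if (failed(tensor::getOrCreateDestinations(b, op->getLoc(), op, destinations)))
  return failure(); // result shapes could not be reified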
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
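A minimal sketch, assuming staticSizes (ArrayRef<int64_t> containing ShapedType::kDynamic markers) and dynamicSizes (ValueRange) are taken from an op's attributes and operands; the names are illustrative.
SmallVector<OpFoldResult> mixedSizes =
    getMixedValues(staticSizes, dynamicSizes, b.getContext());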
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
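A minimal sketch, assuming stride is an OpFoldResult taken from an op's mixed strides.
std::optional<int64_t> cst = getConstantIntValue(stride);
if (cst && *cst == 1) {
  // The stride is statically known to be 1; take the contiguous fast path.
}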
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
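A small worked example: for a row-major 4x5 shape the strides are {5, 1}, so linear index 7 maps to (1, 2) because 7 = 1*5 + 2*1.
SmallVector<int64_t> indices = delinearize(/*linearIndex=*/7, /*strides=*/{5, 1});
// indices == {1, 2}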
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
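A minimal sketch of the inverse of getMixedValues, assuming mixedSizes is an ArrayRef<OpFoldResult>: attribute entries become static values, Value entries go to the dynamic vector with kDynamic placeholders left in the static vector.
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticSizes;
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);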
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:24
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:91
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
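A small worked example: reducing tensor<1x4x1x8xf32> to tensor<4x8xf32> drops the unit dimensions 0 and 2; an empty optional means no valid reduction exists.
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask({1, 4, 1, 8}, {4, 8});
// mask contains {0, 2}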
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
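A minimal sketch, assuming lhsType and rhsType are ranked ShapedTypes gathered while verifying op: two dimensions are compatible when they are equal or at least one of them is dynamic.
if (failed(verifyCompatibleShape(lhsType.getShape(), rhsType.getShape())))
  return op->emitOpError("operand shapes are incompatible");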
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.