MLIR 23.0.0git
TensorOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Builders.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/Matchers.h"
33#include "mlir/Support/LLVM.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SmallBitVector.h"
37#include "llvm/ADT/SmallVectorExtras.h"
38#include "llvm/ADT/StringRef.h"
39#include "llvm/Support/Casting.h"
40#include "llvm/Support/MathExtras.h"
41#include <optional>
42
43using namespace mlir;
44using namespace mlir::tensor;
45
46/// Materialize a single constant operation from a given attribute value with
47/// the desired resultant type.
48Operation *TensorDialect::materializeConstant(OpBuilder &builder,
49 Attribute value, Type type,
50 Location loc) {
51 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
52 return op;
53 if (complex::ConstantOp::isBuildableWith(value, type))
54 return complex::ConstantOp::create(builder, loc, type,
55 llvm::cast<ArrayAttr>(value));
56 return nullptr;
57}
58
60 int64_t dim) {
61 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
62 if (tensorType.isDynamicDim(dim))
63 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
64
65 return builder.getIndexAttr(tensorType.getDimSize(dim));
66}
67
69 Location loc, Value value) {
70 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
72 for (int64_t i = 0; i < tensorType.getRank(); ++i)
73 result.push_back(getMixedSize(builder, loc, value, i));
74 return result;
75}
76
78 OpResult opResult) {
79 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
80 assert(tensorType && "expected tensor type");
81
82 // If the op has a destination, it implements DestinationStyleOpInterface and
83 // we can query the destination operand from that interface.
84 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
85 if (destOp)
86 return destOp.getTiedOpOperand(opResult)->get();
87
88 // Otherwise, create a new destination tensor with the same shape.
90 b.setInsertionPoint(opResult.getDefiningOp());
91
92 // Compute sizes.
94 if (!tensorType.hasStaticShape()) {
95 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
96 ReifiedRankedShapedTypeDims reifiedShapes;
97 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
98 return failure();
99 mixedSizes = reifiedShapes[opResult.getResultNumber()];
100 } else {
101 // Static shape: Take static sizes directly.
102 for (int64_t sz : tensorType.getShape())
103 mixedSizes.push_back(b.getIndexAttr(sz));
104 }
105
106 // Create empty tensor.
107 Value emptyTensor =
108 tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
109 return emptyTensor;
110}
111
113 Operation *op,
115 for (OpResult opResult : op->getResults()) {
116 if (llvm::isa<TensorType>(opResult.getType())) {
117 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
118 if (failed(destination))
119 return failure();
120 result.push_back(*destination);
121 }
122 }
123 return success();
124}
125
127 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
128 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
129 return rtp1.getShape() == rtp2.getShape() &&
130 rtp1.getElementType() == rtp2.getElementType();
131 return false;
132 }
133 return tp1 == tp2; // default implementation
134}
135
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
///
/// `reducedShape` is the shape of the rank-reduced tensor; `mixedSizes` are the
/// slice sizes in the unreduced rank. Returns one bit per entry of
/// `mixedSizes`; a set bit marks a dropped dimension. The two ranges are
/// matched back-to-front.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  // Cursor into `reducedShape`, walking from the last dim towards the front.
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    // Index of this size in the (unreversed) `mixedSizes`.
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: Dim is dropped.
    droppedDims.set(idx);
  }

  // Every dim of the reduced shape must have been consumed.
  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}
177
178/// Given a ranked tensor type and a range of values that defines its dynamic
179/// dimension sizes, turn all dynamic sizes that have a constant value into
180/// static dimension sizes.
181static RankedTensorType
182foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
183 SmallVector<Value> &foldedDynamicSizes) {
184 SmallVector<int64_t> staticShape(type.getShape());
185 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
186 "incorrect number of dynamic sizes");
187
188 // Compute new static and dynamic sizes.
189 unsigned ctr = 0;
190 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
191 if (type.isDynamicDim(i)) {
192 Value dynamicSize = dynamicSizes[ctr++];
193 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
194 if (cst.has_value()) {
195 // Dynamic size must be non-negative.
196 if (cst.value() < 0) {
197 foldedDynamicSizes.push_back(dynamicSize);
198 continue;
199 }
200 staticShape[i] = *cst;
201 } else {
202 foldedDynamicSizes.push_back(dynamicSize);
203 }
204 }
205 }
206
207 return RankedTensorType::get(staticShape, type.getElementType(),
208 type.getEncoding());
209}
210
211//===----------------------------------------------------------------------===//
212// BitcastOp
213//===----------------------------------------------------------------------===//
214
215bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
216 if (inputs.size() != 1 || outputs.size() != 1)
217 return false;
218 Type a = inputs.front(), b = outputs.front();
219 auto aT = dyn_cast<TensorType>(a);
220 auto bT = dyn_cast<TensorType>(b);
221 if (!aT || !bT)
222 return false;
223
224 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
225 return false;
226
227 return succeeded(verifyCompatibleShape(aT, bT));
228}
229
230namespace {
231
232/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
233/// operation.
234struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
235 using OpRewritePattern<BitcastOp>::OpRewritePattern;
236
237 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
238 PatternRewriter &rewriter) const final {
239 auto tensorBitcastOperand =
240 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
241 if (!tensorBitcastOperand)
242 return failure();
243
244 auto resultType = cast<TensorType>(tensorBitcast.getType());
245 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
246 tensorBitcastOperand.getOperand());
247 return success();
248 }
249};
250
251} // namespace
252
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Collapse back-to-back bitcasts into a single one.
  results.add<ChainedTensorBitcast>(context);
}
257
258//===----------------------------------------------------------------------===//
259// CastOp
260//===----------------------------------------------------------------------===//
261
// Results of tensor.cast are named "cast" in the printed IR.
void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
265
266/// Returns true if `target` is a ranked tensor type that preserves static
267/// information available in the `source` ranked tensor type.
269 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
270 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
271
272 // Requires RankedTensorType.
273 if (!sourceType || !targetType)
274 return false;
275
276 // Requires same elemental type.
277 if (sourceType.getElementType() != targetType.getElementType())
278 return false;
279
280 // Requires same rank.
281 if (sourceType.getRank() != targetType.getRank())
282 return false;
283
284 // Requires same encoding.
285 if (sourceType.getEncoding() != targetType.getEncoding())
286 return false;
287
288 // If cast is towards more static sizes along any dimension, don't fold.
289 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
290 if (ShapedType::isStatic(std::get<0>(t)) &&
291 ShapedType::isDynamic(std::get<1>(t)))
292 return false;
293 }
294
295 return true;
296}
297
298/// Determines whether tensor::CastOp casts to a more dynamic version of the
299/// source tensor. This is useful to fold a tensor.cast into a consuming op and
300/// implement canonicalization patterns for ops in different dialects that may
301/// consume the results of tensor.cast operations. Such foldable tensor.cast
302/// operations are typically inserted as `slice` ops and are canonicalized,
303/// to preserve the type compatibility of their uses.
304///
305/// Returns true when all conditions are met:
306/// 1. source and result are ranked tensors with same element type and rank.
307/// 2. the tensor type has more static information than the result
308///
309/// Example:
310/// ```mlir
311/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
312/// %2 = consumer %1 ... : tensor<?x?xf32> ...
313/// ```
314///
315/// folds into:
316///
317/// ```mlir
318/// %2 = consumer %0 ... : tensor<8x16xf32> ...
319/// ```
321 if (!castOp)
322 return false;
323
324 // Can fold if the source of cast has at least as much static information as
325 // its results.
326 return preservesStaticInformation(castOp.getType(),
327 castOp.getSource().getType());
328}
329
330/// Determines whether the tensor::CastOp casts to a more static version of the
331/// source tensor. This is useful to fold into a producing op and implement
332/// canonicalization patterns with the `tensor.cast` op as the root, but
333/// producer being from different dialects. Returns true when all conditions are
334/// met:
335/// 1. source and result and ranked tensors with same element type and rank.
336/// 2. the result type has more static information than the source.
337///
338/// Example:
339/// ```mlir
340/// %1 = producer ... : tensor<?x?xf32>
341/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
342/// ```
343///
344/// can be canonicalized to :
345///
346/// ```mlir
347/// %2 = producer ... : tensor<8x16xf32>
348/// ```
349/// Not all ops might be canonicalizable this way, but for those that can be,
350/// this method provides a check that it is worth doing the canonicalization.
352 if (!castOp)
353 return false;
354 return preservesStaticInformation(castOp.getSource().getType(),
355 castOp.getType());
356}
357
359 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
360 if (llvm::isa<BlockArgument>(opOperand.get()))
361 return false;
362 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
363 return castOp && canFoldIntoConsumerOp(castOp);
364 });
365}
366
368 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
369 SmallVector<Value> newOperands;
370 newOperands.reserve(op->getNumOperands());
371
372 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
373
374 // Assumes that the result has dpsInits followed by nonDpsInits.
375 int64_t dpsInitIdx = 0;
376 for (OpOperand &opOperand : op->getOpOperands()) {
377 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
378 bool fold = canFoldIntoConsumerOp(tensorCastOp);
379 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
380 if (op.isDpsInit(&opOperand) &&
381 !llvm::isa<MemRefType>(newOperands.back().getType()))
382 newResTy[dpsInitIdx++] = newOperands.back().getType();
383 }
384 return newOperands;
385}
386
387/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
388/// that can be folded.
390 bool folded = false;
391 for (OpOperand &operand : op->getOpOperands()) {
392 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
393 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
394 operand.set(castOp.getOperand());
395 folded = true;
396 }
397 }
398 return success(folded);
399}
400
401bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
402 if (inputs.size() != 1 || outputs.size() != 1)
403 return false;
404 Type a = inputs.front(), b = outputs.front();
405 auto aT = llvm::dyn_cast<TensorType>(a);
406 auto bT = llvm::dyn_cast<TensorType>(b);
407 if (!aT || !bT)
408 return false;
409
410 if (aT.getElementType() != bT.getElementType())
411 return false;
412
413 return succeeded(verifyCompatibleShape(aT, bT));
414}
415
416/// Compute a TensorType that has the joined shape knowledge of the two
417/// given TensorTypes. The element types need to match.
419 assert(one.getElementType() == two.getElementType());
420
421 if (!one.hasRank())
422 return two;
423 if (!two.hasRank())
424 return one;
425
426 int64_t rank = one.getRank();
427 if (rank != two.getRank())
428 return {};
429
431 join.reserve(rank);
432 for (int64_t i = 0; i < rank; ++i) {
433 if (one.isDynamicDim(i)) {
434 join.push_back(two.getDimSize(i));
435 continue;
436 }
437 if (two.isDynamicDim(i)) {
438 join.push_back(one.getDimSize(i));
439 continue;
440 }
441 if (one.getDimSize(i) != two.getDimSize(i))
442 return {};
443 join.push_back(one.getDimSize(i));
444 }
445 return RankedTensorType::get(join, one.getElementType());
446}
447
448namespace {
449
450/// Replaces chains of two tensor.cast operations by a single tensor.cast
451/// operation if doing so does not remove runtime constraints.
452struct ChainedTensorCast : public OpRewritePattern<CastOp> {
453 using OpRewritePattern<CastOp>::OpRewritePattern;
454
455 LogicalResult matchAndRewrite(CastOp tensorCast,
456 PatternRewriter &rewriter) const final {
457 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
458
459 if (!tensorCastOperand)
460 return failure();
461
462 auto sourceType =
463 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
464 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
465 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
466
467 // We can remove the intermediate cast if joining all three produces the
468 // same result as just joining the source and result shapes.
469 auto firstJoin =
470 joinShapes(joinShapes(sourceType, intermediateType), resultType);
471
472 // The join might not exist if the cast sequence would fail at runtime.
473 if (!firstJoin)
474 return failure();
475
476 // The newJoin always exists if the above join exists, it might just contain
477 // less information. If so, we cannot drop the intermediate cast, as doing
478 // so would remove runtime checks.
479 auto newJoin = joinShapes(sourceType, resultType);
480 if (firstJoin != newJoin)
481 return failure();
482
483 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
484 tensorCastOperand.getOperand());
485 return success();
486 }
487};
488
/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
/// tensor<128x512xf32> to tensor<?x512xf32>
/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
/// tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    // Bail if there is no extract_slice producer, the cast is not foldable
    // into a producer, or the cast does not change the shape at all.
    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    // Copy the now-static sizes from the cast's result type into the slice
    // sizes, skipping rank-reduced (dropped) dimensions, which have no
    // counterpart in the result shape.
    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    // Rebuild the slice with the refined sizes and result type.
    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
541
542} // namespace
543
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Collapse cast chains and fold casts into extract_slice producers.
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
548
549//===----------------------------------------------------------------------===//
550// ConcatOp
551//===----------------------------------------------------------------------===//
552
553RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
554 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
555 auto tensorTypes =
556 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
557 int64_t concatRank = tensorTypes[0].getRank();
558
559 // The concatenation dim must be in the range [0, rank).
560 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
561
562 SmallVector<int64_t> sizes(concatRank);
563 for (int64_t i = 0, e = concatRank; i < e; ++i) {
564 if (i == dim)
565 continue;
566 SaturatedInteger size;
567 for (auto tensorType : tensorTypes)
568 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
569 sizes[i] = size.asInteger();
570 }
571 auto concatSize = SaturatedInteger::wrap(0);
572 for (auto tensorType : tensorTypes)
573 concatSize =
574 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
575 sizes[dim] = concatSize.asInteger();
576 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
577}
578
// Builder that infers the result type from `dim` and the input types.
void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  // inferResultType returns a type by value, so this cannot fail here.
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}
586
LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  // All inputs must have the same rank as the result.
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  // All inputs must share the result's element type.
  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  // Infer the result shape: non-concatenated dims must agree statically
  // across the inputs; the concatenated dim is the saturating sum.
  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  // The declared result type must agree with the inferred one wherever both
  // are static.
  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      // NOTE(review): the stream has no space after `resultType`, so the
      // diagnostic renders as "...>does not match..." — confirm and fix the
      // message text separately.
      return emitOpError("result type ")
             << resultType << "does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
}
647
// Decompose the concat into an empty destination tensor plus one insert_slice
// per input, offset along the concatenation dim.
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  // s0 + s1: accumulates the running size along the concatenation dim.
  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  // First pass: record each input's shape, its offset along the concat dim,
  // and build up the overall output shape.
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,

  // Second pass: insert each input at its precomputed offset.
  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
        builder, loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  // Cast back if static-ness of the built value differs from the result type.
  if (replacement.getType() != getType()) {
    replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
  }
}
695
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  // The first input serves as the source for any dims that must be queried
  // dynamically.
  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // Pre-populate the result sizes with as much static information as possible
  // from the given result type, as well as the inferred result type, otherwise
  // use the dim sizes from the first input.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Take the sum of the input sizes along the concatenated dim.
    AffineExpr sum = builder.getAffineDimExpr(0);
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // If the result shape is static along the concatenated dim, use the static
    // shape.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}
747
// Results of tensor.concat are named "concat" in the printed IR.
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}
752
753OpFoldResult ConcatOp::fold(FoldAdaptor) {
754 ValueRange inputs = getInputs();
755 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
756 return inputs[0];
757 return {};
758}
759
760namespace {
761/// Fold a concat op with a single input to a cast.
762struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
763 using OpRewritePattern<ConcatOp>::OpRewritePattern;
764
765 LogicalResult matchAndRewrite(ConcatOp concatOp,
766 PatternRewriter &rewriter) const override {
767 if (concatOp.getInputs().size() != 1)
768 return failure();
769 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
770 concatOp.getInputs()[0]);
771 return success();
772 }
773};
774
/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except the
/// concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may safely be refined to that same static
/// shape.
///
/// Example:
///
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
/// tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
/// %2 = tensor.concat dim(0) %0, %cast :
/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute inferred type for operand: keep the operand's own size along
      // the concat dim, merged sizes everywhere else.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use refined operand type and create cast from original operand.
        auto castOp =
            CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                           concatOp.getOperand(operandIdx));
        // Init-capture of `operandIdx` works around capturing a structured
        // binding in a C++17 lambda.
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    return matched;
  }
};
833
834// Ensure `tensor.concat`'s result type is at least as static as can be inferred
835// from its operand types.
836///
837/// Example:
838/// ```mlir
839/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
840/// tensor<?x?xi32>
841/// ```
842/// ->
843/// ```mlir
844/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
845/// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
846/// tensor<?x?xi32>
847/// ```
848struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
849 using OpRewritePattern<ConcatOp>::OpRewritePattern;
850
851 LogicalResult matchAndRewrite(ConcatOp concatOp,
852 PatternRewriter &rewriter) const override {
853 int64_t dim = concatOp.getDim();
854 RankedTensorType inferredResultType =
855 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
856
857 // The result type should be at least as static as inferred result type.
858 if (preservesStaticInformation(inferredResultType,
859 concatOp.getResultType())) {
860 return failure();
861 }
862
863 auto newConcatOp =
864 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
865 concatOp->getOperands());
866 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
867 newConcatOp);
868
869 return success();
870 }
871};
872} // namespace
873
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Fold single-input concats and tighten operand/result static shapes.
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}
880
881//===----------------------------------------------------------------------===//
882// DimOp
883//===----------------------------------------------------------------------===//
884
// Results of tensor.dim are named "dim" in the printed IR.
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}
888
889void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
890 int64_t index) {
891 auto loc = result.location;
892 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
893 build(builder, result, source, indexValue);
894}
895
896std::optional<int64_t> DimOp::getConstantIndex() {
898}
899
900Speculation::Speculatability DimOp::getSpeculatability() {
901 auto constantIndex = getConstantIndex();
902 if (!constantIndex)
904
905 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
906 if (!rankedSourceType)
908
909 if (rankedSourceType.getRank() <= constantIndex)
911
913}
914
// Integer-range analysis hook: the result range of tensor.dim is derived via
// the shaped-dim interface helper; argRanges[1] is the range of the index
// operand.
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}
920
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    // Dynamic extents are supplied only for dynamic dims, so count the
    // dynamic dims that precede `index`.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
    // `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}
981
982namespace {
983/// Fold dim of a cast into the dim of the source of the tensor cast.
984struct DimOfCastOp : public OpRewritePattern<DimOp> {
985 using OpRewritePattern<DimOp>::OpRewritePattern;
986
987 LogicalResult matchAndRewrite(DimOp dimOp,
988 PatternRewriter &rewriter) const override {
989 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
990 if (!castOp)
991 return failure();
992 Value newSource = castOp.getOperand();
993 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
994 return success();
995 }
996};
997
998/// Fold dim of a destination passing style op into the dim of the corresponding
999/// init.
1000struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1001 using OpRewritePattern<DimOp>::OpRewritePattern;
1002
1003 LogicalResult matchAndRewrite(DimOp dimOp,
1004 PatternRewriter &rewriter) const override {
1005 auto source = dimOp.getSource();
1006 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1007 if (!destOp)
1008 return failure();
1009
1010 auto resultIndex = cast<OpResult>(source).getResultNumber();
1011 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1012
1013 rewriter.modifyOpInPlace(
1014 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1015 return success();
1016 }
1017};
1018
1019/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
1020/// operand.
1021struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1022 using OpRewritePattern<DimOp>::OpRewritePattern;
1023
1024 LogicalResult matchAndRewrite(DimOp dim,
1025 PatternRewriter &rewriter) const override {
1026 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1027
1028 if (!reshape)
1029 return failure();
1030
1031 // Since tensors are immutable we don't need to worry about where to place
1032 // the extract call
1033 rewriter.setInsertionPointAfter(dim);
1034 Location loc = dim.getLoc();
1035 Value extract =
1036 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1037 if (extract.getType() != dim.getType())
1038 extract =
1039 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1040 rewriter.replaceOp(dim, extract);
1041 return success();
1042 }
1043};
1044} // namespace
1045
1046void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1047 MLIRContext *context) {
1048 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// EmptyOp
1053//===----------------------------------------------------------------------===//
1054
1055void EmptyOp::build(OpBuilder &builder, OperationState &result,
1056 ArrayRef<int64_t> staticShape, Type elementType,
1057 Attribute encoding) {
1058 assert(none_of(staticShape, ShapedType::isDynamic) &&
1059 "expected only static sizes");
1060 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1061}
1062
1063void EmptyOp::build(OpBuilder &builder, OperationState &result,
1064 ArrayRef<int64_t> staticShape, Type elementType,
1065 ValueRange dynamicSizes, Attribute encoding) {
1066 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1067 build(builder, result, tensorType, dynamicSizes);
1068}
1069
1070void EmptyOp::build(OpBuilder &builder, OperationState &result,
1071 ArrayRef<OpFoldResult> sizes, Type elementType,
1072 Attribute encoding) {
1073 SmallVector<int64_t> staticShape;
1074 SmallVector<Value> dynamicSizes;
1075 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1076 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1077}
1078
1079LogicalResult EmptyOp::verify() {
1080 return verifyDynamicDimensionCount(getOperation(), getType(),
1081 getDynamicSizes());
1082}
1083
1084LogicalResult
1085EmptyOp::reifyResultShapes(OpBuilder &builder,
1086 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1087 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1088 unsigned ctr = 0;
1089 for (int64_t i = 0; i < getType().getRank(); ++i) {
1090 if (getType().isDynamicDim(i)) {
1091 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1092 } else {
1093 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1094 }
1095 }
1096 return success();
1097}
1098
1099Value EmptyOp::getDynamicSize(unsigned idx) {
1100 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1101 unsigned ctr = 0;
1102 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1103 if (getType().isDynamicDim(i))
1104 ++ctr;
1105 return getDynamicSizes()[ctr];
1106}
1107
1108SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1109 SmallVector<OpFoldResult> result;
1110 unsigned ctr = 0;
1111 Builder b(getContext());
1112 for (int64_t dim : getType().getShape()) {
1113 if (ShapedType::isDynamic(dim)) {
1114 result.push_back(getDynamicSizes()[ctr++]);
1115 } else {
1116 result.push_back(b.getIndexAttr(dim));
1117 }
1118 }
1119 return result;
1120}
1121
1122namespace {
1123/// Change the type of the result of a `tensor.empty` by making the result
1124/// type statically sized along dimensions that in the original operation were
1125/// defined as dynamic, but the size was defined using a `constant` op. For
1126/// example
1127///
1128/// %c5 = arith.constant 5: index
1129/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1130///
1131/// to
1132///
1133/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1134struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1135 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1136
1137 LogicalResult matchAndRewrite(EmptyOp op,
1138 PatternRewriter &rewriter) const override {
1139 SmallVector<Value> foldedDynamicSizes;
1140 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1141 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1142
1143 // Stop here if no dynamic size was promoted to static.
1144 if (foldedTensorType == op.getType())
1145 return failure();
1146
1147 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1148 foldedDynamicSizes);
1149 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1150 return success();
1151 }
1152};
1153
1154struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1155 using OpRewritePattern<DimOp>::OpRewritePattern;
1156
1157 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1158 PatternRewriter &rewriter) const override {
1159 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1160 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1161 if (!emptyTensorOp || !maybeConstantIndex)
1162 return failure();
1163 auto emptyTensorType = emptyTensorOp.getType();
1164 if (*maybeConstantIndex < 0 ||
1165 *maybeConstantIndex >= emptyTensorType.getRank() ||
1166 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1167 return failure();
1168 rewriter.replaceOp(dimOp,
1169 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1170 return success();
1171 }
1172};
1173
1174/// Canonicalize
1175///
1176/// ```mlir
1177/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1178/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1179/// ```
1180///
1181/// into
1182///
1183/// ```mlir
1184/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1185/// ```
1186///
1187/// This assumes the input program is correct in terms of its shape. So it is
1188/// safe to assume that `%d0` is in fact 4.
1189struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1190 using OpRewritePattern<CastOp>::OpRewritePattern;
1191
1192 LogicalResult matchAndRewrite(CastOp castOp,
1193 PatternRewriter &rewriter) const override {
1194 if (!canFoldIntoProducerOp(castOp))
1195 return failure();
1196 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1197 if (!producer)
1198 return failure();
1199
1200 auto resultType =
1201 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1202 ArrayRef<int64_t> resultShape = resultType.getShape();
1203 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1204 SmallVector<OpFoldResult> newMixedSizes;
1205 newMixedSizes.reserve(currMixedSizes.size());
1206 assert(resultShape.size() == currMixedSizes.size() &&
1207 "mismatch in result shape and sizes of empty op");
1208 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1209 int64_t newDim = std::get<0>(it);
1210 OpFoldResult currDim = std::get<1>(it);
1211 // Case 1: The empty tensor dim is static. Check that the tensor cast
1212 // result dim matches.
1213 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1214 if (ShapedType::isDynamic(newDim) ||
1215 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1216 // Something is off, the cast result shape cannot be more dynamic
1217 // than the empty tensor result shape (enforced by
1218 // `canFoldIntoProducer`). Abort for now.
1219 return rewriter.notifyMatchFailure(
1220 producer, "mismatch in static value of shape of empty tensor "
1221 "result and cast result");
1222 }
1223 newMixedSizes.push_back(attr);
1224 continue;
1225 }
1226
1227 // Case 2 : The tensor cast shape is static, but empty tensor result
1228 // shape is dynamic.
1229 if (ShapedType::isStatic(newDim)) {
1230 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1231 continue;
1232 }
1233
1234 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1235 // shape is dynamic. Use the dynamic value from the empty tensor op.
1236 newMixedSizes.push_back(currDim);
1237 }
1238
1239 // TODO: Do not drop tensor encoding.
1240 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1241 resultType.getElementType());
1242 return success();
1243 }
1244};
1245
1246} // namespace
1247
1248void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1249 MLIRContext *context) {
1250 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1251 ReplaceEmptyTensorStaticShapeDims>(context);
1252}
1253
1254//===----------------------------------------------------------------------===//
1255// ExtractOp
1256//===----------------------------------------------------------------------===//
1257
1258namespace {
1259
1260/// Canonicalizes the pattern of the form
1261///
1262/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
1263/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1264///
1265/// to
1266///
1267/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1268struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1269 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1270
1271 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1272 PatternRewriter &rewriter) const final {
1273 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1274 if (!tensorCast)
1275 return failure();
1276 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1277 return failure();
1278 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1279 extract, tensorCast.getSource(), extract.getIndices());
1280 return success();
1281 }
1282};
1283
/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
/// tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] :
/// tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    // The delinearization below needs concrete extents for every source dim.
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    // Each collapsed index expands into one index per source dimension of
    // its reassociation group.
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      // A singleton group maps its index through unchanged.
      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      // Delinearize the collapsed index over the group's static extents.
      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    // Degenerate case: an empty reassociation list (e.g. collapse to rank-0
    // semantics) — every source index is the constant 0.
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      // NOTE(review): the zero is built via a folded affine apply,
      // presumably so it can fold with surrounding affine ops — confirm.
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};
1345
1346} // namespace
1347
1348void ExtractOp::getAsmResultNames(
1349 function_ref<void(Value, StringRef)> setNameFn) {
1350 setNameFn(getResult(), "extracted");
1351}
1352
1353LogicalResult ExtractOp::verify() {
1354 // Verify the # indices match if we have a ranked type.
1355 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1356 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1357 return emitOpError("incorrect number of indices for extract_element");
1358 return success();
1359}
1360
/// If we have an ExtractOp consuming an InsertOp with the same
/// indices, we can return the InsertOp's scalar directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  // Element-wise comparator for the two index lists.
  // NOTE(review): the lambda body is elided in this rendering — presumably
  // it compares `a` and `b` for (constant) equality; confirm upstream.
  auto isSame = [](Value a, Value b) {
  };
  // Fold only when the scalar's type matches and all indices agree.
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}
1377
1378OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1379 if (Attribute tensor = adaptor.getTensor()) {
1380 // If this is a splat elements attribute, simply return the value.
1381 // All of the elements of a splat attribute are the same.
1382 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1383 return splatTensor.getSplatValue<Attribute>();
1384
1385 // If this is a dense resource elements attribute, return.
1386 if (isa<DenseResourceElementsAttr>(tensor))
1387 return {};
1388 }
1389
1390 // Collect the constant indices into the tensor.
1391 SmallVector<uint64_t, 8> indices;
1392 for (Attribute indice : adaptor.getIndices()) {
1393 if (!indice || !llvm::isa<IntegerAttr>(indice))
1394 return {};
1395 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1396 }
1397
1398 // Fold extract(from_elements(...)).
1399 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1400 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1401 auto rank = tensorType.getRank();
1402 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1403 "rank mismatch");
1404 int flatIndex = 0;
1405 int stride = 1;
1406 for (int i = rank - 1; i >= 0; --i) {
1407 flatIndex += indices[i] * stride;
1408 stride *= tensorType.getDimSize(i);
1409 }
1410 // Prevent out of bounds accesses. This can happen in invalid code that
1411 // will never execute.
1412 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1413 flatIndex < 0)
1414 return {};
1415 return fromElementsOp.getElements()[flatIndex];
1416 }
1417
1418 // If this is an elements attribute, query the value at the given indices.
1419 if (Attribute tensor = adaptor.getTensor()) {
1420 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1421 if (elementsAttr && elementsAttr.isValidIndex(indices))
1422 return elementsAttr.getValues<Attribute>()[indices];
1423 }
1424
1425 if (Value result = foldExtractAfterInsert(*this))
1426 return result;
1427
1428 return {};
1429}
1430
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Note: ExtractFromCollapseShape is registered separately via an opt-in
  // populate function and intentionally not added here.
  results.add<ExtractFromTensorCast>(context);
}
1435
1437 RewritePatternSet &patterns) {
1438 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1439}
1440
1441//===----------------------------------------------------------------------===//
1442// FromElementsOp
1443//===----------------------------------------------------------------------===//
1444
1445void FromElementsOp::getAsmResultNames(
1446 function_ref<void(Value, StringRef)> setNameFn) {
1447 setNameFn(getResult(), "from_elements");
1448}
1449
1450void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1451 ValueRange elements) {
1452 assert(!elements.empty() && "expected at least one element");
1453 Type resultType = RankedTensorType::get(
1454 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1455 build(builder, result, resultType, elements);
1456}
1457
1458OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1459 // DenseElementsAttr::get requires StringAttr for element types that are not
1460 // integer, index, float, or complex (e.g. vector types), but folded constants
1461 // won't be StringAttr instances. Only fold for element types directly
1462 // supported by DenseElementsAttr.
1463 Type eltType = getType().getElementType();
1464 if (!eltType.isIntOrIndexOrFloat() && !isa<ComplexType>(eltType))
1465 return {};
1466 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1467 return DenseElementsAttr::get(getType(), adaptor.getElements());
1468 return {};
1469}
1470
1471namespace {
1472
1473// Pushes the index_casts that occur before extractions to after the extract.
1474// This minimizes type conversion in some cases and enables the extract
1475// canonicalizer. This changes:
1476//
1477// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1478// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1479//
1480// to the following:
1481//
1482// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1483// %cast = arith.index_cast %extract : i32 to index
1484//
1485// to just %element.
1486//
1487// Consider expanding this to a template and handle all tensor cast
1488// operations.
1489struct ExtractElementFromIndexCast
1490 : public OpRewritePattern<tensor::ExtractOp> {
1491 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1492
1493 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1494 PatternRewriter &rewriter) const final {
1495 Location loc = extract.getLoc();
1496 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1497 if (!indexCast)
1498 return failure();
1499
1500 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1501
1502 auto newExtract = tensor::ExtractOp::create(
1503 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1504
1505 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1506 newExtract);
1507
1508 return success();
1509 }
1510};
1511
1512} // namespace
1513
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Pushes index_casts feeding extracts below the extract (pattern above).
  results.add<ExtractElementFromIndexCast>(context);
}
1518
1519//===----------------------------------------------------------------------===//
1520// GatherOp
1521//===----------------------------------------------------------------------===//
1522
1523void GatherOp::getAsmResultNames(
1524 function_ref<void(Value, StringRef)> setNameFn) {
1525 setNameFn(getResult(), "gather");
1526}
1527
1528/// Return the inferred result type for a gatherOp where:
1529/// - sourceType is the type of the source tensor gathered from
1530/// - indicesType is the type of the indices used to gather
1531/// - gatherDims are the dims along which the gather occurs.
1532/// Return a full rank or ranked-reduced variant of the type depending on
1533/// the value of rankReduced.
1534///
1535/// The leading dimensions of the index tensor give the result tensor its
1536/// leading dimensions.
1537/// The trailing dimensions of the result tensor are obtained from the source
1538/// tensor by setting the dimensions specified in gather_dims to `1` (if
1539/// rankedReduced is false), or skipping them (otherwise).
1540RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1541 RankedTensorType indicesType,
1542 ArrayRef<int64_t> gatherDims,
1543 bool rankReduced) {
1544 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1545 resultShape.reserve(resultShape.size() + sourceType.getRank());
1546 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1547 if (llvm::binary_search(gatherDims, idx)) {
1548 if (!rankReduced)
1549 resultShape.push_back(1);
1550 continue;
1551 }
1552 resultShape.push_back(sourceType.getDimSize(idx));
1553 }
1554 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1555}
1556
1557static LogicalResult
1560 StringRef gatherOrScatter, StringRef sourceOrDest) {
1561 if (dims.empty())
1562 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1563
1564 int64_t numGatherDims = dims.size();
1565 if (numGatherDims > rank)
1566 return op->emitOpError(gatherOrScatter)
1567 << "_dims overflow " << sourceOrDest << " rank";
1568 if (indices.empty() || indices.back() != numGatherDims)
1569 return op->emitOpError(gatherOrScatter)
1570 << "_dims length must match the size of last dimension of indices";
1571 for (int64_t val : dims) {
1572 if (val < 0)
1573 return op->emitOpError(gatherOrScatter)
1574 << "_dims value must be non-negative";
1575 if (val >= rank)
1576 return op->emitOpError(gatherOrScatter)
1577 << "_dims value must be smaller than " << sourceOrDest << " rank";
1578 }
1579 for (int64_t i = 1; i < numGatherDims; ++i) {
1580 if (dims[i - 1] >= dims[i])
1581 return op->emitOpError(gatherOrScatter)
1582 << "_dims values must be strictly increasing";
1583 }
1584 return success();
1585}
1586
1587LogicalResult GatherOp::verify() {
1588 int64_t sourceRank = getSourceType().getRank();
1589 ArrayRef<int64_t> gatherDims = getGatherDims();
1590 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1591 getIndicesType().getShape(), sourceRank,
1592 "gather", "source")))
1593 return failure();
1594
1595 RankedTensorType expectedResultType = GatherOp::inferResultType(
1596 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1597 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1598 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1599 if (getResultType() != expectedResultType &&
1600 getResultType() != expectedRankReducedResultType) {
1601 return emitOpError("result type "
1602 "mismatch: "
1603 "expected ")
1604 << expectedResultType << " or its rank-reduced variant "
1605 << expectedRankReducedResultType << " (got: " << getResultType()
1606 << ")";
1607 }
1608
1609 return success();
1610}
1611
1612OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1613 if (OpFoldResult reshapedSource = reshapeConstantSource(
1614 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1615 getResult().getType()))
1616 return reshapedSource;
1617 return {};
1618}
1619
1620//===----------------------------------------------------------------------===//
1621// InsertOp
1622//===----------------------------------------------------------------------===//
1623
1624void InsertOp::getAsmResultNames(
1625 function_ref<void(Value, StringRef)> setNameFn) {
1626 setNameFn(getResult(), "inserted");
1627}
1628
1629LogicalResult InsertOp::verify() {
1630 // Verify the # indices match if we have a ranked type.
1631 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1632 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1633 return emitOpError("incorrect number of indices");
1634 return success();
1635}
1636
1637OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1638 Attribute scalar = adaptor.getScalar();
1639 Attribute dest = adaptor.getDest();
1640 if (scalar && dest)
1641 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1642 if (scalar == splatDest.getSplatValue<Attribute>())
1643 return dest;
1644 return {};
1645}
1646
1647//===----------------------------------------------------------------------===//
1648// GenerateOp
1649//===----------------------------------------------------------------------===//
1650
1651void GenerateOp::getAsmResultNames(
1652 function_ref<void(Value, StringRef)> setNameFn) {
1653 setNameFn(getResult(), "generated");
1654}
1655
1656LogicalResult GenerateOp::reifyResultShapes(
1657 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1658 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1659 int idx = 0;
1660 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1661 if (getType().isDynamicDim(dim)) {
1662 reifiedReturnShapes[0][dim] = getOperand(idx++);
1663 } else {
1664 reifiedReturnShapes[0][dim] =
1665 builder.getIndexAttr(getType().getDimSize(dim));
1666 }
1667 }
1668 return success();
1669}
1670
1671LogicalResult GenerateOp::verify() {
1672 // Ensure that the tensor type has as many dynamic dimensions as are
1673 // specified by the operands.
1674 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1675 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
1676 getOperands())))
1677 return failure();
1678 return success();
1679}
1680
1681LogicalResult GenerateOp::verifyRegions() {
1682 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1683 // Ensure that region arguments span the index space.
1684 if (!llvm::all_of(getBody().getArgumentTypes(),
1685 [](Type ty) { return ty.isIndex(); }))
1686 return emitError("all body arguments must be index");
1687 if (getBody().getNumArguments() != resultTy.getRank())
1688 return emitError("must have one body argument per input dimension");
1689
1690 // Ensure that the region yields an element of the right type.
1691 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1692
1693 if (yieldOp.getValue().getType() != resultTy.getElementType())
1694 return emitOpError(
1695 "body must be terminated with a `yield` operation of the tensor "
1696 "element type");
1697
1698 return success();
1699}
1700
1701void GenerateOp::build(
1702 OpBuilder &b, OperationState &result, Type resultTy,
1703 ValueRange dynamicExtents,
1704 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1705 build(b, result, resultTy, dynamicExtents);
1706
1707 // Build and populate body.
1708 OpBuilder::InsertionGuard guard(b);
1709 Region *bodyRegion = result.regions.front().get();
1710 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1711 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1712 SmallVector<Location, 2> argumentLocs(rank, result.location);
1713 Block *bodyBlock =
1714 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1715 bodyBuilder(b, result.location, bodyBlock->getArguments());
1716}
1717
1718namespace {
1719
1720/// Canonicalizes tensor.generate operations with a constant
1721/// operand into the equivalent operation with the operand expressed in the
1722/// result type, instead. We also insert a type cast to make sure that the
1723/// resulting IR is still well-typed.
1724struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1725 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1726
1727 LogicalResult matchAndRewrite(GenerateOp generateOp,
1728 PatternRewriter &rewriter) const final {
1729 SmallVector<Value> foldedDynamicSizes;
1730 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1731 generateOp.getType(), generateOp.getDynamicExtents(),
1732 foldedDynamicSizes);
1733
1734 // Stop here if no dynamic size was promoted to static.
1735 if (foldedTensorType == generateOp.getType())
1736 return failure();
1737
1738 auto loc = generateOp.getLoc();
1739 auto newOp =
1740 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1741 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1742 newOp.getBody().begin());
1743 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1744 generateOp.getType(), newOp);
1745 return success();
1746 }
1747};
1748
1749/// Canonicalizes the pattern of the form
1750///
1751/// %tensor = tensor.generate %x {
1752/// ^bb0(%arg0: index):
1753/// <computation>
1754/// yield %1 : index
1755/// } : tensor<?xindex>
1756/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1757///
1758/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1759/// tensor.generate operation has no side-effects.
1760struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1761 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1762
1763 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1764 PatternRewriter &rewriter) const final {
1765 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1766 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1767 return failure();
1768
1769 IRMapping mapping;
1770 Block *body = &tensorFromElements.getBody().front();
1771 mapping.map(body->getArguments(), extract.getIndices());
1772 for (auto &op : body->without_terminator())
1773 rewriter.clone(op, mapping);
1774
1775 auto yield = cast<YieldOp>(body->getTerminator());
1776
1777 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1778 return success();
1779 }
1780};
1781
1782} // namespace
1783
1784void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1785 MLIRContext *context) {
1786 // TODO: Move extract pattern to tensor::ExtractOp.
1787 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1788}
1789
1790//===----------------------------------------------------------------------===//
1791// RankOp
1792//===----------------------------------------------------------------------===//
1793
1794void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1795 setNameFn(getResult(), "rank");
1796}
1797
1798OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1799 // Constant fold rank when the rank of the operand is known.
1800 auto type = getOperand().getType();
1801 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1802 if (shapedType && shapedType.hasRank())
1803 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1804 return IntegerAttr();
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// ReshapeOp
1809//===----------------------------------------------------------------------===//
1810
1811void ReshapeOp::getAsmResultNames(
1812 function_ref<void(Value, StringRef)> setNameFn) {
1813 setNameFn(getResult(), "reshape");
1814}
1815
1816static int64_t getNumElements(ShapedType type) {
1817 int64_t numElements = 1;
1818 for (auto dim : type.getShape())
1819 numElements *= dim;
1820 return numElements;
1821}
1822
1823LogicalResult ReshapeOp::verify() {
1824 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1825 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1826
1827 if (operandType.getElementType() != resultType.getElementType())
1828 return emitOpError("element types of source and destination tensor "
1829 "types should be the same");
1830
1831 int64_t shapeSize =
1832 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1833 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1834 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1835
1836 if (resultRankedType) {
1837 if (operandRankedType && resultRankedType.hasStaticShape() &&
1838 operandRankedType.hasStaticShape()) {
1839 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1840 return emitOpError("source and destination tensor should have the "
1841 "same number of elements");
1842 }
1843 if (ShapedType::isDynamic(shapeSize))
1844 return emitOpError("cannot use shape operand with dynamic length to "
1845 "reshape to statically-ranked tensor type");
1846 if (shapeSize != resultRankedType.getRank())
1847 return emitOpError(
1848 "length of shape operand differs from the result's tensor rank");
1849 }
1850 return success();
1851}
1852
/// Folds tensor.reshape: constant sources fold to a reshaped constant;
/// reshape-of-reshape bypasses the producer; identity reshapes (same static
/// type and rank <= 1, or a shape operand provably equal to the source's
/// runtime shape) fold to the source value.
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
  // producer's input instead as the original tensor to reshape. This could
  // render such producer dead code.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  // Only identical ranked source/result types can possibly be no-ops below.
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both 0D or 1D tensors and have the same type,
  // the reshape has no effect, even if the tensor is dynamically shaped.
  if (sourceTy.getRank() <= 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    // The reshape is a no-op when every shape element matches the
    // corresponding source dimension, either as the same constant or as
    // `tensor.dim %source, <id>`.
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      // Constant shape element: must equal the static source dim size.
      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      // Dynamic shape element: must be tensor.dim of the same source at the
      // same (constant) dimension index.
      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      // Any other shape element defeats the no-op proof.
      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
1909
1910//===----------------------------------------------------------------------===//
1911// Reassociative reshape ops
1912//===----------------------------------------------------------------------===//
1913
1914void CollapseShapeOp::getAsmResultNames(
1915 function_ref<void(Value, StringRef)> setNameFn) {
1916 setNameFn(getResult(), "collapsed");
1917}
1918
1919void ExpandShapeOp::getAsmResultNames(
1920 function_ref<void(Value, StringRef)> setNameFn) {
1921 setNameFn(getResult(), "expanded");
1922}
1923
1924int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1925 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1926 "invalid resultDim");
1927 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1928 if (llvm::is_contained(it.value(), resultDim))
1929 return it.index();
1930 llvm_unreachable("could not find reassociation group");
1931}
1932
1933FailureOr<SmallVector<OpFoldResult>>
1934ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1935 RankedTensorType expandedType,
1936 ArrayRef<ReassociationIndices> reassociation,
1937 ArrayRef<OpFoldResult> inputShape) {
1938 std::optional<SmallVector<OpFoldResult>> outputShape =
1939 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1940 inputShape);
1941 if (!outputShape)
1942 return failure();
1943 return *outputShape;
1944}
1945
1946SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1947 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1948}
1949
1950void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1951 Type resultType, Value src,
1952 ArrayRef<ReassociationIndices> reassociation,
1953 ArrayRef<OpFoldResult> outputShape) {
1954 auto [staticOutputShape, dynamicOutputShape] =
1955 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1956 build(builder, result, cast<RankedTensorType>(resultType), src,
1957 getReassociationIndicesAttribute(builder, reassociation),
1958 dynamicOutputShape, staticOutputShape);
1959}
1960
1961void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1962 Type resultType, Value src,
1963 ArrayRef<ReassociationIndices> reassociation) {
1964 SmallVector<OpFoldResult> inputShape =
1965 getMixedSizes(builder, result.location, src);
1966 auto tensorResultTy = cast<RankedTensorType>(resultType);
1967 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1968 builder, result.location, tensorResultTy, reassociation, inputShape);
1969 SmallVector<OpFoldResult> outputShapeOrEmpty;
1970 if (succeeded(outputShape)) {
1971 outputShapeOrEmpty = *outputShape;
1972 }
1973 build(builder, result, tensorResultTy, src, reassociation,
1974 outputShapeOrEmpty);
1975}
1976
1977SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1978 return getSymbolLessAffineMaps(getReassociationExprs());
1979}
1980SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1982 getReassociationIndices());
1983}
1984
1985SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1986 return getSymbolLessAffineMaps(getReassociationExprs());
1987}
1988SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1990 getReassociationIndices());
1991}
1992
1993RankedTensorType CollapseShapeOp::inferCollapsedType(
1994 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1995 return inferCollapsedType(
1997 type.getContext(), reassociation)));
1998}
1999
2000/// Compute the RankedTensorType obtained by applying `reassociation` to
2001/// `type`.
2002RankedTensorType
2003CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2004 ArrayRef<AffineMap> reassociation) {
2005 auto shape = type.getShape();
2006 SmallVector<int64_t, 4> newShape;
2007 newShape.reserve(reassociation.size());
2008
2009 // Use the fact that reassociation is valid to simplify the logic: only use
2010 // each map's rank.
2011 assert(isReassociationValid(reassociation) && "invalid reassociation");
2012 unsigned currentDim = 0;
2013 for (AffineMap m : reassociation) {
2014 unsigned dim = m.getNumResults();
2015 auto band = shape.slice(currentDim, dim);
2016 int64_t size = 1;
2017 if (llvm::is_contained(band, ShapedType::kDynamic))
2018 size = ShapedType::kDynamic;
2019 else
2020 for (unsigned d = 0; d < dim; ++d)
2021 size *= shape[currentDim + d];
2022 newShape.push_back(size);
2023 currentDim += dim;
2024 }
2025
2026 return RankedTensorType::get(newShape, type.getElementType());
2027}
2028
2029void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2030 ArrayRef<ReassociationIndices> reassociation,
2031 ArrayRef<NamedAttribute> attrs) {
2032 auto srcType = llvm::cast<RankedTensorType>(src.getType());
2033 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2034 auto resultType =
2035 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2036 srcType.getEncoding());
2037 result.addAttribute(getReassociationAttrStrName(),
2038 getReassociationIndicesAttribute(b, reassociation));
2039 build(b, result, resultType, src, attrs);
2040}
2041
2042template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2043 TensorReshapeOp, ExpandShapeOp>::value>
2044static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2045 RankedTensorType expandedType,
2046 RankedTensorType collapsedType) {
2047 if (failed(
2048 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2049 return failure();
2050
2051 // Reshape must preserve the number of elements when statically known.
2052 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
2053 int64_t expandedNumElements = expandedType.getNumElements();
2054 int64_t collapsedNumElements = collapsedType.getNumElements();
2055 if (expandedNumElements != collapsedNumElements) {
2056 return op.emitOpError("number of elements must be preserved: ")
2057 << expandedNumElements << " != " << collapsedNumElements;
2058 }
2059 }
2060
2061 auto maps = op.getReassociationMaps();
2062 RankedTensorType expectedType =
2063 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2064 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2065 return op.emitOpError("expected collapsed type to be ")
2066 << expectedType << ", but got " << collapsedType;
2067 return success();
2068}
2069
LogicalResult ExpandShapeOp::verify() {
  RankedTensorType srcType = getSrc().getType();
  RankedTensorType resultType = getResult().getType();

  // static_output_shape must carry one entry per result dimension.
  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  // One dynamic SSA value is expected per kDynamic sentinel in
  // static_output_shape.
  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that the number of dynamic dims in output_shape matches the number
  // of dynamic dims in the result type.
  if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
                                         getOutputShape())))
    return failure();

  // Verify if provided output shapes are in agreement with output type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape))
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos])
      return emitOpError("invalid output shape provided at pos ") << pos;

  // Shared reshape checks (element-count preservation, reassociation
  // agreement).
  return verifyTensorReshapeOp(*this, resultType, srcType);
}
2103
2104LogicalResult CollapseShapeOp::verify() {
2105 CollapseShapeOp op = *this;
2106 if (llvm::any_of(op.getReassociationIndices(),
2107 [](ReassociationIndices group) { return group.empty(); })) {
2108 return op.emitOpError("reassociation indices must not be empty");
2109 }
2110 RankedTensorType srcType = op.getSrc().getType();
2111 RankedTensorType resultType = op.getResult().getType();
2112
2113 return verifyTensorReshapeOp(op, srcType, resultType);
2114}
2115
2116namespace {
2117/// Reshape of a splat constant can be replaced with a constant of the result
2118/// type.
2119template <typename TensorReshapeOp>
2120struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2121 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2122 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2123 PatternRewriter &rewriter) const override {
2124 DenseElementsAttr attr;
2125 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2126 return failure();
2127 if (!attr || !attr.isSplat())
2128 return failure();
2129 // DenseElementsAttr requires a static shape; skip folding for dynamic
2130 // result types.
2131 if (!reshapeOp.getResultType().hasStaticShape())
2132 return failure();
2133 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2134 reshapeOp.getResultType(), attr.getRawData());
2135 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2136 return success();
2137 }
2138};
2139
2140// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2141template <typename TensorReshapeOp>
2142class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2143public:
2144 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2145
2146 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2147 PatternRewriter &rewriter) const override {
2148 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2149 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2150 return failure();
2151
2152 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2153 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2154 return success();
2155 }
2156};
2157
2158/// Reshape of a FromElements can be replaced with a FromElements of the
2159/// result type
2160template <typename TensorReshapeOp>
2161struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2162 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2163 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2164 PatternRewriter &rewriter) const override {
2165 auto fromElements =
2166 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2167 if (!fromElements)
2168 return failure();
2169
2170 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2171
2172 if (!shapedTy.hasStaticShape())
2173 return failure();
2174
2175 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2176 fromElements.getElements());
2177 return success();
2178 }
2179};
2180
2181// Fold CastOp into CollapseShapeOp when adding static information.
2182struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2183 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2184
2185 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2186 PatternRewriter &rewriter) const override {
2187 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2188 if (!tensor::canFoldIntoConsumerOp(castOp))
2189 return failure();
2190
2191 RankedTensorType srcType =
2192 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2193 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2194 srcType, collapseShapeOp.getReassociationMaps());
2195
2196 if (newResultType == collapseShapeOp.getResultType()) {
2197 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2198 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2199 });
2200 } else {
2201 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2202 newResultType, castOp.getSource(),
2203 collapseShapeOp.getReassociation());
2204 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2205 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2206 }
2207 return success();
2208 }
2209};
2210
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    // Start from the expand's result shape and try to replace dynamic dims
    // with constants matched from the output_shape operands. `outputIt` walks
    // the dynamic operands in lockstep with the dynamic result dims.
    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires at
        // least one of the expanded dimensions to be dynamic if the input is
        // dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        // A constant operand makes this result dim static; anything else
        // stays dynamic.
        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    // Rebuild: cast the source to the refined input type, expand with the
    // refined output shape, then cast back to the original result type so
    // consumers are unaffected.
    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
2290} // namespace
2291
2292void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2293 MLIRContext *context) {
2294 results.add<
2295 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2296 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2297 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2298 FoldReshapeWithSplat<ExpandShapeOp>,
2299 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2300}
2301
2302void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2303 MLIRContext *context) {
2304 results.add<
2305 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2306 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2307 tensor::DimOp, RankedTensorType>,
2308 FoldReshapeWithConstant<CollapseShapeOp>,
2309 FoldReshapeWithSplat<CollapseShapeOp>,
2310 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2311 context);
2312}
2313
2314OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2316 adaptor.getOperands());
2317}
2318
2319OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2321 adaptor.getOperands());
2322}
2323
2324//===----------------------------------------------------------------------===//
2325// ExtractSliceOp
2326//===----------------------------------------------------------------------===//
2327
2328void ExtractSliceOp::getAsmResultNames(
2329 function_ref<void(Value, StringRef)> setNameFn) {
2330 setNameFn(getResult(), "extracted_slice");
2331}
2332
2333/// An extract_slice result type can be inferred, when it is not
2334/// rank-reduced, from the source type and the static representation of
2335/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2336RankedTensorType
2337ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2338 ArrayRef<int64_t> staticSizes) {
2339 // An extract_slice op may specify only a leading subset of offset/sizes/
2340 // strides in which case we complete with offset=0, sizes from memref type
2341 // and strides=1.
2342 assert(static_cast<int64_t>(staticSizes.size()) ==
2343 sourceTensorType.getRank() &&
2344 "unexpected staticSizes not equal to rank of source");
2345 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2346 sourceTensorType.getEncoding());
2347}
2348
2349// TODO: This uses neither offsets nor strides!
2350RankedTensorType
2351ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2352 ArrayRef<OpFoldResult> sizes) {
2353 SmallVector<int64_t> staticSizes;
2354 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2355
2356 assert(static_cast<int64_t>(staticSizes.size()) ==
2357 sourceTensorType.getRank() &&
2358 "unexpected staticSizes not equal to rank of source");
2359 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2360 sourceTensorType.getEncoding());
2361}
2362
2363/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2364/// number of sizes), drop as many size 1 as needed to produce an inferred
2365/// type with the desired rank.
2366///
2367/// Note that there may be multiple ways to compute this rank-reduced type:
2368/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2369///
2370/// To disambiguate, this function always drops the first 1 sizes occurrences.
2371RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2372 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2373 ArrayRef<int64_t> sizes) {
2374 // Type inferred in the absence of rank-reducing behavior.
2375 auto inferredType = llvm::cast<RankedTensorType>(
2376 inferResultType(sourceRankedTensorType, sizes));
2377 int rankDiff = inferredType.getRank() - desiredResultRank;
2378 if (rankDiff > 0) {
2379 auto shape = inferredType.getShape();
2380 llvm::SmallBitVector dimsToProject =
2381 getPositionsOfShapeOne(rankDiff, shape);
2382 SmallVector<int64_t> projectedShape;
2383 // Best effort rank-reducing: drop 1s in order.
2384 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2385 if (!dimsToProject.test(pos))
2386 projectedShape.push_back(shape[pos]);
2387 inferredType =
2388 RankedTensorType::get(projectedShape, inferredType.getElementType());
2389 }
2390 return inferredType;
2391}
2392
2393RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2394 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2395 ArrayRef<OpFoldResult> sizes) {
2396 SmallVector<int64_t> staticSizes;
2397 SmallVector<Value> dynamicSizes;
2398 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2399 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2400 desiredResultRank, sourceRankedTensorType, staticSizes);
2401}
2402
2403/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2404/// result type. If the type passed is nullptr, it is inferred.
2405void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2406 RankedTensorType resultType, Value source,
2407 ArrayRef<OpFoldResult> offsets,
2408 ArrayRef<OpFoldResult> sizes,
2409 ArrayRef<OpFoldResult> strides,
2410 ArrayRef<NamedAttribute> attrs) {
2411 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2412 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2413 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2414 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2415 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2416 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2417 // Structuring implementation this way avoids duplication between builders.
2418 if (!resultType) {
2419 resultType = llvm::cast<RankedTensorType>(
2420 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2421 }
2422 result.addAttributes(attrs);
2423 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2424 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2425 b.getDenseI64ArrayAttr(staticSizes),
2426 b.getDenseI64ArrayAttr(staticStrides));
2427}
2428
2429/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2430/// result type.
2431void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2432 ArrayRef<OpFoldResult> offsets,
2433 ArrayRef<OpFoldResult> sizes,
2434 ArrayRef<OpFoldResult> strides,
2435 ArrayRef<NamedAttribute> attrs) {
2436 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2437}
2438
2439/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2440/// a Range vector.
2441void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2442 ArrayRef<Range> ranges,
2443 ArrayRef<NamedAttribute> attrs) {
2444 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2445 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2446}
2447
2448/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2449/// the type passed is nullptr, it is inferred.
2450void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2451 RankedTensorType resultType, Value source,
2452 ValueRange offsets, ValueRange sizes,
2453 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2454 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2455 offsets, [](Value v) -> OpFoldResult { return v; });
2456 SmallVector<OpFoldResult> sizeValues =
2457 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2458 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2459 strides, [](Value v) -> OpFoldResult { return v; });
2460 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2461}
2462
/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  // A null result type asks the custom-type overload to infer it.
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
2469
2471 Operation *op,
2472 RankedTensorType expectedType) {
2473 switch (result) {
2475 return success();
2477 return op->emitError("expected rank to be smaller or equal to ")
2478 << "the other rank. ";
2480 return op->emitError("expected type to be ")
2481 << expectedType << " or a rank-reduced version. (size mismatch) ";
2483 return op->emitError("expected element type to be ")
2484 << expectedType.getElementType();
2485 default:
2486 llvm_unreachable("unexpected extract_slice op verification result");
2487 }
2488}
2489
2490/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2491/// result type, offsets set to 0 and strides set to 1.
2492void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2493 RankedTensorType resultType, Value source,
2494 ArrayRef<OpFoldResult> sizes,
2495 ArrayRef<NamedAttribute> attrs) {
2496 Attribute zeroIdxAttr = b.getIndexAttr(0);
2497 Attribute oneIdxAttr = b.getIndexAttr(1);
2498 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2499 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2500 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2501}
2502
2503/// Verifier for ExtractSliceOp.
2504LogicalResult ExtractSliceOp::verify() {
2505 RankedTensorType sourceType = getSourceType();
2506
2507 // Verify result type against inferred type.
2508 RankedTensorType expectedType =
2509 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2512 return produceSliceErrorMsg(result, *this, expectedType);
2513
2514 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2515 // to the source tensor.
2516 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2517 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2518 getStaticStrides(), /*generateErrorMessage=*/true);
2519 if (!boundsResult.isValid)
2520 return getOperation()->emitError(boundsResult.errorMessage);
2521
2522 return success();
2523}
2524
2525llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2526 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2527}
2528
2529FailureOr<Value>
2530ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2531 ArrayRef<int64_t> desiredShape) {
2532 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2533 assert(sourceTensorType && "not a ranked tensor type");
2534 auto sourceShape = sourceTensorType.getShape();
2535 if (sourceShape.equals(desiredShape))
2536 return value;
2537 auto maybeRankReductionMask =
2538 mlir::computeRankReductionMask(sourceShape, desiredShape);
2539 if (!maybeRankReductionMask)
2540 return failure();
2542 b, loc, value,
2543 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2544}
2545
2546LogicalResult ExtractSliceOp::reifyResultShapes(
2547 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2548 reifiedReturnShapes.resize(1);
2549 reifiedReturnShapes[0].reserve(getType().getRank());
2550 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2551 llvm::SmallBitVector droppedDims = getDroppedDims();
2552 for (const auto &size : enumerate(mixedSizes)) {
2553 if (droppedDims.test(size.index()))
2554 continue;
2555 reifiedReturnShapes[0].push_back(size.value());
2556 }
2557 return success();
2558}
2559
2560namespace {
2561/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2562/// This essentially pushes memref_cast past its consuming slice when
2563/// `canFoldIntoConsumerOp` is true.
2564///
2565/// Example:
2566/// ```
2567/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2568/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2569/// tensor<3x4xf32>
2570/// ```
2571/// is rewritten into:
2572/// ```
2573/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2574/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2575/// ```
2576class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2577public:
2578 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2579
2580 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2581 PatternRewriter &rewriter) const override {
2582 // Any constant operand, just return to let the constant folder kick in.
2583 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2584 return matchPattern(operand, matchConstantIndex());
2585 }))
2586 return failure();
2587
2588 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2589 if (!castOp)
2590 return failure();
2591
2592 if (!canFoldIntoConsumerOp(castOp))
2593 return failure();
2594
2595 // Pattern does not apply if the produced op would not verify.
2596 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2597 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2598 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2599 sliceOp.getStaticStrides());
2600 if (!sliceResult.isValid)
2601 return failure();
2602
2603 // Create folded extract.
2604 Location loc = sliceOp.getLoc();
2605 Value newResult = ExtractSliceOp::create(
2606 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2607 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2608 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2609 sliceOp.getStaticStrides());
2610 rewriter.replaceOp(sliceOp, newResult);
2611 return success();
2612 }
2613};
2614
2615/// Slice elements from `values` into `outValues`. `counts` represents the
2616/// numbers of elements to stride in the original values for each dimension.
2617/// The output values can be used to construct a DenseElementsAttr.
2618template <typename IterTy, typename ElemTy>
2619static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2620 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2621 ArrayRef<int64_t> strides,
2622 llvm::SmallVectorImpl<ElemTy> *outValues) {
2623 assert(offsets.size() == sizes.size());
2624 assert(offsets.size() == strides.size());
2625 if (offsets.empty())
2626 return;
2627
2628 int64_t offset = offsets.front();
2629 int64_t size = sizes.front();
2630 int64_t stride = strides.front();
2631 if (offsets.size() == 1) {
2632 for (int64_t i = 0; i < size; ++i, offset += stride)
2633 outValues->push_back(*(values + offset));
2634
2635 return;
2636 }
2637
2638 for (int64_t i = 0; i < size; ++i, offset += stride) {
2639 auto begin = values + offset * counts.front();
2640 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2641 offsets.drop_front(), sizes.drop_front(),
2642 strides.drop_front(), outValues);
2643 }
2644}
2645
/// Fold arith.constant and tensor.extract_slice into arith.constant. The
/// folded operation might introduce more constant data; Users can control
/// their heuristics by the control function.
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  // NOTE(review): the constructor's second parameter (the control-function
  // argument stored into `controlFn`) is elided in this rendering.
  ConstantOpExtractSliceFolder(MLIRContext *context,
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Only applies when the slice source is a constant.
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();

    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shape is not supported.
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    // counts[i] = number of elements spanned by one step along dimension i
    // (i.e., the product of all trailing dimension sizes).
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // Slice the elements and construct a new attribute.
    SmallVector<Attribute> outValues;
    outValues.reserve(resultType.getNumElements());
    sliceElements(attr.value_begin<Attribute>(), counts, offsets, sizes,
                  strides, &outValues);
    auto newAttr = DenseElementsAttr::get(resultType, outValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
    return success();
  }

private:
  /// This additionally controls whether the fold happens or not. Users can
  /// impose their heuristics in the function.
  // NOTE(review): the declaration of the `controlFn` member is elided in this
  // rendering.
};
2718
2719} // namespace
2720
// NOTE(review): the function header line (the qualified name of
// mlir::tensor::populateFoldConstantExtractSlicePatterns) is elided in this
// rendering.
    RewritePatternSet &patterns,
    const ControlConstantExtractSliceFusionFn &controlFn) {
  // Register the constant/extract_slice folding pattern with the
  // caller-provided control function.
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
}
2726
/// Return the canonical type of the result of an extract_slice op.
// NOTE(review): the `struct SliceReturnTypeCanonicalizer {` header line is
// elided in this rendering.
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a tensor type without taking into account any rank reductions.
    RankedTensorType nonReducedType =
        ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);

    // Directly return the non-rank reduced type if there are no dropped
    // dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Build the reduced shape, preserving the original rank reduction pattern.
    SmallVector<int64_t> targetShape;
    for (auto i : llvm::seq<int64_t>(mixedSizes.size()))
      if (!droppedDims.test(i))
        targetShape.push_back(nonReducedType.getDimSize(i));

    // Keep the element type and encoding of the non-reduced inferred type.
    return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
                                 nonReducedType.getEncoding());
  }
};
2753
/// A canonicalizer wrapper to replace ExtractSliceOps.
// NOTE(review): the `struct SliceCanonicalizer {` header line is elided in
// this rendering.
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    // If canonicalization changed the result type, cast back to the original
    // type so existing users of `op` remain type-correct.
    if (replacement.getType() != op.getType())
      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
                                           replacement);
    rewriter.replaceOp(op, replacement);
  }
};
2765
2766void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2767 MLIRContext *context) {
2768 results.add<
2769 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2770 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2771 ExtractSliceOpCastFolder>(context);
2772}
2773
2774//
2775static LogicalResult
2776foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2777 ShapedType shapedType) {
2778 OpBuilder b(op.getContext());
2779 for (OpFoldResult ofr : op.getMixedOffsets())
2780 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2781 return failure();
2782 // Rank-reducing noops only need to inspect the leading dimensions:
2783 // llvm::zip is appropriate.
2784 auto shape = shapedType.getShape();
2785 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2786 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2787 return failure();
2788 for (OpFoldResult ofr : op.getMixedStrides())
2789 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2790 return failure();
2791 return success();
2792}
2793
2794/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2795/// slice, we can return the InsertSliceOp's source directly.
2796// TODO: This only checks the immediate producer; extend to go up the
2797// insert/extract chain if the slices are disjoint.
2798static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2799 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2800
2801 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2802 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2803 insertOp.isSameAs(extractOp, isSame))
2804 return insertOp.getSource();
2805
2806 return {};
2807}
2808
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  // Fold a slice of a constant splat into a smaller splat of the result type.
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  // An identity slice of the source folds to the source itself.
  // NOTE(review): the second conjunct of this condition is elided in this
  // rendering.
  if (getSourceType() == getType() &&
    return this->getSource();
  // A round-trip insert/extract of the same slice folds to the inserted
  // value.
  if (Value slice = foldExtractAfterInsertSlice(*this))
    return slice;

  return OpFoldResult();
}
2822
// NOTE(review): the function header line (the qualified name of
// createCanonicalRankReducingExtractSliceOp) is elided in this rendering.
    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
  // Build a whole-tensor extract: zero offsets, unit strides, full sizes;
  // `targetType` may rank-reduce the result.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  // NOTE(review): the line computing `sizes` (the mixed sizes of `tensor`)
  // is elided in this rendering.
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
}
2833
2834//===----------------------------------------------------------------------===//
2835// InsertSliceOp
2836//===----------------------------------------------------------------------===//
2837
/// Suggest "inserted_slice" as the SSA-value name prefix for insert_slice
/// results in the printed IR.
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}
2842
// Build a InsertSliceOp with mixed static and dynamic entries.
// NOTE(review): the `sizes` and `attrs` parameter lines of this signature
// are elided in this rendering.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> strides,
  // Split each OpFoldResult into its static (constant) and dynamic (SSA
  // value) components; the static arrays always have one entry per dim.
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  // The result type of an insert_slice is always the destination type.
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
2860
2861/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2862/// Range vector.
2863void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2864 Value dest, ArrayRef<Range> ranges,
2865 ArrayRef<NamedAttribute> attrs) {
2866 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2867 build(b, result, source, dest, offsets, sizes, strides, attrs);
2868}
2869
2870// Build a InsertSliceOp with dynamic entries.
2871void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2872 Value dest, ValueRange offsets, ValueRange sizes,
2873 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2874 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2875 offsets, [](Value v) -> OpFoldResult { return v; });
2876 SmallVector<OpFoldResult> sizeValues =
2877 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2878 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2879 strides, [](Value v) -> OpFoldResult { return v; });
2880 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2881}
2882
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
// NOTE(review): the `verifyInsertSliceOp(` header line (static function name
// and result type) is elided in this rendering.
    RankedTensorType srcType, RankedTensorType dstType,
    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
    ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice, use the same type
  // inference.
  RankedTensorType expected =
      ExtractSliceOp::inferResultType(dstType, staticSizes);
  // Optionally report the inferred type back to the caller (used to build
  // verifier error messages).
  if (expectedType)
    *expectedType = expected;
  // The source must match `expected` up to dropped unit dimensions.
  return isRankReducedType(expected, srcType);
}
2897
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
  // Verify result type against inferred type.
  RankedTensorType expectedType;
  // NOTE(review): the line binding the SliceVerificationResult of
  // verifyInsertSliceOp and the line testing it for success are elided in
  // this rendering.
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
2918
2919/// If we have two consecutive InsertSliceOp writing to the same slice, we
2920/// can mutate the second InsertSliceOp's destination to the first one's.
2921///
2922/// Example:
2923///
2924/// ```mlir
2925/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2926/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2927/// ```
2928///
2929/// folds into:
2930///
2931/// ```mlir
2932/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2933/// ```
2934///
2935/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2936static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2937 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2938
2939 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2940 if (!prevInsertOp ||
2941 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2942 !prevInsertOp.isSameAs(insertOp, isSame))
2943 return failure();
2944
2945 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2946 return success();
2947}
2948
2949/// Folds round-trip extract/insert slice op pairs.
2950/// Example:
2951/// ```mlir
2952/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2953/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2954/// ```
2955/// can be folded into %val.
2956static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2957 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2958
2959 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2960 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2961 !extractOp.isSameAs(insertOp, isSame))
2962 return nullptr;
2963
2964 return extractOp.getSource();
2965}
2966
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  // Inserting a statically-shaped, identity slice of identical type is a
  // no-op: fold to the inserted source.
  // NOTE(review): the last conjunct of this condition is elided in this
  // rendering.
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
    return this->getSource();
  // Bypass a previous insert that wrote the exact same slice.
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  // A round-trip extract/insert of the same slice folds to the destination.
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;
  // Inserting an empty (zero-sized) slice leaves the destination unchanged.
  if (llvm::any_of(getMixedSizes(), isZeroInteger))
    return getDest();
  return OpFoldResult();
}
2980
2981LogicalResult InsertSliceOp::reifyResultShapes(
2982 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2983 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2984 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2985 return success();
2986}
2987
2988namespace {
2989/// Pattern to rewrite a insert_slice op with constant arguments.
2990///
2991/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2992template <typename InsertOpTy>
2993class InsertSliceOpConstantArgumentFolder final
2994 : public OpRewritePattern<InsertOpTy> {
2995public:
2996 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2997
2998 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2999 PatternRewriter &rewriter) const override {
3000 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
3001 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3002 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
3003
3004 // No constant operands were folded, just return;
3005 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
3006 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
3007 failed(foldDynamicStrideList(mixedStrides)))
3008 return failure();
3009
3010 // Pattern does not apply if the produced op would not verify.
3011 SliceBoundsVerificationResult sliceResult =
3012 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
3013 mixedOffsets, mixedSizes, mixedStrides);
3014 if (!sliceResult.isValid)
3015 return failure();
3016
3017 // Create the new op in canonical form.
3018 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
3019 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
3020 mixedSizes);
3021 Value toInsert = insertSliceOp.getSource();
3022 if (sourceType != insertSliceOp.getSourceType()) {
3023 OpBuilder::InsertionGuard g(rewriter);
3024 // The only difference between InsertSliceOp and ParallelInsertSliceOp
3025 // is that the insertion point is just before the InParallelOp in
3026 // the parallel case.
3027 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
3028 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3029 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3030 sourceType, toInsert);
3031 }
3032 rewriter.replaceOpWithNewOp<InsertOpTy>(
3033 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
3034 mixedSizes, mixedStrides);
3035 return success();
3036 }
3037};
3038
/// Fold tensor_casts with insert_slice operations. If the source or
/// destination tensor is a tensor_cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result did
/// not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand: let the constant folder handle it instead.
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Returns the cast's source when `v` is produced by a foldable
    // tensor.cast, std::nullopt otherwise.
    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    // At least one of source/dest must come through a foldable cast.
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not seen
    // in the insert slice op static sizes, so we ignore dynamic dims when
    // computing the rank reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();
    // Replace dimensions in the insert slice op with corresponding static dims
    // from the cast source type. If the insert slice sizes have static dims
    // that are not static in the tensor.cast source (i.e., when the cast op
    // casts a dynamic dim to static), the dim should not be replaced, and the
    // pattern will fail later in `verifyInsertSliceOp`.
    // Note: `size` aliases the element of `staticSizes` (llvm::enumerate
    // yields references), so the assignment below updates `staticSizes` in
    // place alongside `mixedSizes`.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // Pattern does not apply if the produced op would not verify.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
                            mixedSizes, insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement =
        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    // Cast back to the original destination type when the dest cast was
    // folded away, so users of the result keep seeing the same type.
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),
                                           replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
3143
3144/// If additional static type information can be deduced from a insert_slice's
3145/// size operands, insert an explicit cast of the op's source operand. This
3146/// enables other canonicalization patterns that are matching for tensor_cast
3147/// ops such as `ForOpTensorCastFolder` in SCF.
3148///
3149/// Example:
3150///
3151/// ```mlir
3152/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3153/// : tensor<?x?xf32> into ...
3154/// ```
3155///
3156/// folds into:
3157///
3158/// ```mlir
3159/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3160/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3161/// : tensor<64x64xf32> into ...
3162/// ```
3163///
3164/// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
3165template <typename InsertOpTy>
3166struct InsertSliceOpSourceCastInserter final
3167 : public OpRewritePattern<InsertOpTy> {
3168 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3169
3170 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3171 PatternRewriter &rewriter) const override {
3172 RankedTensorType srcType = insertSliceOp.getSourceType();
3173 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3174 return failure();
3175 SmallVector<int64_t> newSrcShape(srcType.getShape());
3176 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3177 if (std::optional<int64_t> constInt =
3178 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3179 // Bail on invalid IR.
3180 if (*constInt < 0)
3181 return failure();
3182 newSrcShape[i] = *constInt;
3183 }
3184 }
3185 if (!hasValidSizesOffsets(newSrcShape))
3186 return failure();
3187
3188 RankedTensorType newSrcType = RankedTensorType::get(
3189 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3190 if (srcType == newSrcType ||
3191 !preservesStaticInformation(srcType, newSrcType) ||
3192 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3193 return failure();
3194
3195 // newSrcType is:
3196 // 1) Different from srcType.
3197 // 2) "More static" than srcType.
3198 // 3) Cast-compatible with srcType.
3199 // Insert the cast.
3200 OpBuilder::InsertionGuard g(rewriter);
3201 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3202 // that the insertion point is just before the InParallelOp in the
3203 // parallel case.
3204 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3205 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3206 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3207 newSrcType, insertSliceOp.getSource());
3208 rewriter.replaceOpWithNewOp<InsertOpTy>(
3209 insertSliceOp, cast, insertSliceOp.getDest(),
3210 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3211 insertSliceOp.getMixedStrides());
3212 return success();
3213 }
3214};
3215} // namespace
3216
/// Return a bit vector marking which destination dims the source does not
/// provide, derived from the source shape and the slice's mixed sizes.
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
3220
3221void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3222 MLIRContext *context) {
3223 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3224 InsertSliceOpCastFolder<InsertSliceOp>,
3225 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3226}
3227
// NOTE(review): the function header line (the qualified name of
// createCanonicalRankReducingInsertSliceOp) is elided in this rendering.
                                    Location loc,
                                    Value tensor,
                                    Value dest) {
  // Build a full-tensor insert: zero offsets, unit strides, and the
  // destination's (mixed static/dynamic) sizes.
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
3240
3241//===----------------------------------------------------------------------===//
3242// PadOp
3243//===----------------------------------------------------------------------===//
3244
/// Suggest "padded" as the SSA-value name prefix for pad results in the
/// printed IR.
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}
3248
3249LogicalResult PadOp::verify() {
3250 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3251 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3252 auto expectedType =
3253 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3254 if (!expectedType) {
3255 return emitError("failed to infer expectedType from sourceType ")
3256 << sourceType << ", specified resultType is " << resultType;
3257 }
3258 if (resultType.getRank() != expectedType.getRank()) {
3259 return emitError("specified type ")
3260 << resultType << " does not match the inferred type "
3261 << expectedType;
3262 }
3263 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3264 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3265 continue;
3266 if (expectedType.isDynamicDim(i))
3267 continue;
3268 return emitError("specified type ")
3269 << resultType << " does not match the inferred type "
3270 << expectedType;
3271 }
3272
3273 return success();
3274}
3275
3276LogicalResult PadOp::verifyRegions() {
3277 auto &region = getRegion();
3278 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3279 Block &block = region.front();
3280 if (block.getNumArguments() != rank)
3281 return emitError("expected the block to have ") << rank << " arguments";
3282
3283 // Note: the number and type of yield values are checked in the YieldOp.
3284 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3285 if (!en.value().isIndex())
3286 return emitOpError("expected block argument ")
3287 << (en.index() + 1) << " to be an index";
3288 }
3289
3290 // Ensure that the region yields an element of the right type.
3291 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3292 if (yieldOp.getValue().getType() !=
3293 llvm::cast<ShapedType>(getType()).getElementType())
3294 return emitOpError("expected yield type to match shape element type");
3295
3296 return success();
3297}
3298
3299RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3300 ArrayRef<int64_t> staticLow,
3301 ArrayRef<int64_t> staticHigh,
3302 ArrayRef<int64_t> resultShape) {
3303 unsigned rank = sourceType.getRank();
3304 if (staticLow.size() != rank)
3305 return RankedTensorType();
3306 if (staticHigh.size() != rank)
3307 return RankedTensorType();
3308 if (!resultShape.empty() && resultShape.size() != rank)
3309 return RankedTensorType();
3310
3311 SmallVector<int64_t, 4> inferredShape;
3312 for (auto i : llvm::seq<unsigned>(0, rank)) {
3313 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3314 staticHigh[i] == ShapedType::kDynamic) {
3315 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3316 : resultShape[i]);
3317 } else {
3318 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3319 assert((resultShape.empty() || size == resultShape[i] ||
3320 resultShape[i] == ShapedType::kDynamic) &&
3321 "mismatch between inferred shape and result shape");
3322 inferredShape.push_back(size);
3323 }
3324 }
3325
3326 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3327}
3328
3329void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3330 Value source, ArrayRef<int64_t> staticLow,
3331 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3332 bool nofold, ArrayRef<NamedAttribute> attrs) {
3333 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3334 if (!resultType)
3335 resultType = inferResultType(sourceType, staticLow, staticHigh);
3336 result.addAttributes(attrs);
3337 build(b, result, resultType, source, low, high,
3338 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3339 nofold ? b.getUnitAttr() : UnitAttr());
3340}
3341
3342void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3343 Value source, ValueRange low, ValueRange high, bool nofold,
3344 ArrayRef<NamedAttribute> attrs) {
3345 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3346 unsigned rank = sourceType.getRank();
3347 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3348 build(b, result, resultType, source, staticVector, staticVector, low, high,
3349 nofold, attrs);
3350}
3351
3352void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3353 Value source, ArrayRef<OpFoldResult> low,
3354 ArrayRef<OpFoldResult> high, bool nofold,
3355 ArrayRef<NamedAttribute> attrs) {
3356 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3357 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3358 SmallVector<int64_t, 4> staticLow, staticHigh;
3359 // staticLow and staticHigh have full information of the padding config.
3360 // This will grow staticLow and staticHigh with 1 value. If the config is
3361 // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
3362 // value as well.
3363 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3364 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3365 if (!resultType) {
3366 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3367 }
3368 assert(llvm::isa<RankedTensorType>(resultType));
3369 result.addAttributes(attrs);
3370 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3371 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3372 nofold ? b.getUnitAttr() : UnitAttr());
3373}
3374
3375void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3376 Value source, ArrayRef<OpFoldResult> low,
3377 ArrayRef<OpFoldResult> high, Value constantPadValue,
3378 bool nofold, ArrayRef<NamedAttribute> attrs) {
3379 build(b, result, resultType, source, low, high, nofold, attrs);
3380
3381 // Add a region and a block to yield the pad value.
3382 Region *region = result.regions[0].get();
3383 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3384 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3385 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3386
3387 // `builder.createBlock` changes the insertion point within the block. Create
3388 // a guard to reset the insertion point of the builder after it is destroyed.
3389 OpBuilder::InsertionGuard guard(b);
3390 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3391 tensor::YieldOp::create(b, result.location, constantPadValue);
3392}
3393
3394llvm::SmallBitVector PadOp::getPaddedDims() {
3395 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3396 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3397 for (const auto &en : enumerate(paddingWidths))
3398 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3399 paddedDims.set(en.index());
3400 };
3401 extractPaddedDims(getMixedLowPad());
3402 extractPaddedDims(getMixedHighPad());
3403 return paddedDims;
3404}
3405
3406namespace {
3407// Folds tensor.pad when padding is static zeros and the attribute
3408// doesn't request otherwise.
3409struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3410 using OpRewritePattern<PadOp>::OpRewritePattern;
3411
3412 LogicalResult matchAndRewrite(PadOp padTensorOp,
3413 PatternRewriter &rewriter) const override {
3414 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3415 return failure();
3416 if (padTensorOp.getNofold())
3417 return failure();
3418 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3419 padTensorOp, padTensorOp.getResult().getType(),
3420 padTensorOp.getSource());
3421 return success();
3422 }
3423};
3424
// Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Only casts that can legally be folded into a consumer are of interest
    // (canFoldIntoConsumerOp also rejects a null castOp).
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    // Re-infer the result type from the cast's (more static) source type,
    // constrained by the existing result shape.
    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      // Result type is unchanged: bypass the cast in place.
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      // Result type became more static: build a new pad (cloning the pad
      // body) and cast its result back to the original type for users.
      auto newOp = PadOp::create(
          rewriter, padTensorOp->getLoc(), newResultType,
          padTensorOp.getSource(), padTensorOp.getStaticLow(),
          padTensorOp.getStaticHigh(), padTensorOp.getLow(),
          padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
3460
// Fold CastOp using the result of PadOp back into the latter if it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Only handle the case where the cast is the sole user of the pad.
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    // The cast must refine (not lose) static shape information.
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    // Rebuild the pad with the cast's (more static) destination type,
    // reusing the original pad body via takeBody.
    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    // Both the pad and the trailing cast are replaced by the new pad result.
    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
3491
3492/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3493/// different dimensions. The pattern applies if the following preconditions
3494/// hold:
3495/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3496/// 2) the tensor::ExtractSliceOps have only unit-strides,
3497/// 3) the tensor::PadOps perform only high-padding,
3498/// 4) the tensor::PadOps have the same constant padding value,
3499/// 5) the tensor::PadOps do not have common padding dimensions,
3500/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3501/// zero-offset for every dimension.
3502/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3503/// the
3504/// padded source dimensions.
3505///
3506/// Example:
3507///
3508/// ```mlir
3509/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3510/// : tensor<64x64xf32> to tensor<?x64xf32>
3511/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3512/// } : tensor<?x64xf32> to tensor<8x64xf32>
3513/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3514/// : tensor<8x64xf32> to tensor<8x?xf32>
3515/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3516/// } : tensor<8x?xf32> to tensor<8x4xf32>
3517/// ```
3518///
3519/// folds into:
3520///
3521/// ```mlir
3522/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3523/// : tensor<64x64xf32> to tensor<?x?xf32>
3524/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3525/// } : tensor<?x?xf32> to tensor<8x4xf32>
3526/// ```
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // Match the chain: outerSliceOp -> outerPadOp -> innerSliceOp -> padOp.
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    // Both pads must yield the same constant; compare the matched attributes.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // for every dimension, and use the offset the other pair. Fail if no
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // exists.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
    // of the outer tensor::ExtractSliceOp for the dimensions padded by the
    // outer tensor::PadOp and fail if the size of the inner
    // tensor::ExtractSliceOp does not match the size of the padded dimension.
    // Otherwise, take the size of the inner tensor::ExtractSliceOp.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps. Preconditions 3)
    // and 5) guarantee that for every dimension at most one pad contributes.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // the two paddings in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
3649
/// Replaces dynamic pad amounts that are produced by constant ops with their
/// static equivalents, computing a more static result type and casting the
/// new result back to the original type for existing users.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    // newLows/newHighs keep only the operands that remain dynamic.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    // Fill the dynamic entries of the static pad arrays with the constants
    // recovered from the operands above; lowCount/highCount walk the dynamic
    // operand lists in order.
    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        // A dimension becomes static only when the input extent and both pad
        // amounts are static.
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
             staticHigh[i] == ShapedType::kDynamic ||
             inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    // Bail out when nothing became more static.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    // Preserve the pad body and cast back to the original result type so
    // existing users remain valid.
    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);

    return success();
  }
};
3749
3750/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3751///
3752/// Example:
3753///
3754/// ```mlir
3755/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3756/// tensor.yield %val
3757/// } : tensor<1x2xf32> to tensor<2x5xf32>
3758/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3759/// tensor.yield %val
3760/// } : tensor<1x5xf32> to tensor<5x7xf32>
3761/// ```
3762///
3763/// folds into:
3764///
3765/// ```mlir
3766/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3767/// tensor.yield %val
3768/// } : tensor<1x2xf32> to tensor<5x7xf32>
3769/// ```
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // A nofold pad must be preserved as-is.
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    // Note: this compares the yielded Values, so both pads must yield the
    // very same SSA value.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    // Each combined entry is the (constant-folded) affine sum d0 + d1 of the
    // consumer's and producer's padding in that dimension.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    // Pad the producer's source directly with the combined paddings, keeping
    // the consumer's result type and moving over its pad body.
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
3826
3827} // namespace
3828
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // One result; produce an OpFoldResult per result dimension.
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    // Static result dimensions are emitted directly as index attributes.
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    // Dynamic dimension: result size = source size + low pad + high pad,
    // expressed as a composed affine apply so constant operands fold away.
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
3851
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // Register all tensor.pad canonicalization patterns defined above.
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
3858
3859/// Return the padding value of the PadOp if it constant. In this context,
3860/// "constant" means an actual constant or "defined outside of the block".
3861///
3862/// Values are considered constant in three cases:
3863/// - A ConstantLike value.
3864/// - A basic block argument from a different block.
3865/// - A value defined outside of the block.
3866///
3867/// If the padding value is not constant, an empty Value is returned.
3868Value PadOp::getConstantPaddingValue() {
3869 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3870 if (!yieldOp)
3871 return {};
3872 Value padValue = yieldOp.getValue();
3873 // Check if yield value is a constant.
3874 if (matchPattern(padValue, m_Constant()))
3875 return padValue;
3876 // Check if yield value is defined inside the PadOp block.
3877 if (padValue.getParentBlock() == &getRegion().front())
3878 return {};
3879 // Else: Yield value defined outside of the PadOp block.
3880 return padValue;
3881}
3882
3883OpFoldResult PadOp::fold(FoldAdaptor) {
3884 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3885 !getNofold())
3886 return getSource();
3887 return {};
3888}
3889
3890//===----------------------------------------------------------------------===//
3891// ParallelInsertSliceOp
3892//===----------------------------------------------------------------------===//
3893
3894OpResult ParallelInsertSliceOp::getTiedOpResult() {
3895 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3896 for (const auto &it :
3897 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3898 Operation &nextOp = it.value();
3899 if (&nextOp == getOperation())
3900 return parallelCombiningParent.getParentResult(it.index());
3901 }
3902 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3903}
3904
3905// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3906void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3907 Value source, Value dest,
3908 ArrayRef<OpFoldResult> offsets,
3909 ArrayRef<OpFoldResult> sizes,
3910 ArrayRef<OpFoldResult> strides,
3911 ArrayRef<NamedAttribute> attrs) {
3912 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3913 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3914 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3915 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3916 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3917 result.addAttributes(attrs);
3918 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3919 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3920 b.getDenseI64ArrayAttr(staticSizes),
3921 b.getDenseI64ArrayAttr(staticStrides));
3922}
3923
/// Build an ParallelInsertSliceOp with mixed static and dynamic entries
/// packed into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  // Unpack the ranges into separate offset/size/stride lists and delegate to
  // the mixed static/dynamic builder.
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
3933
3934// Build a ParallelInsertSliceOp with dynamic entries.
3935void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3936 Value source, Value dest, ValueRange offsets,
3937 ValueRange sizes, ValueRange strides,
3938 ArrayRef<NamedAttribute> attrs) {
3939 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3940 offsets, [](Value v) -> OpFoldResult { return v; });
3941 SmallVector<OpFoldResult> sizeValues =
3942 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3943 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3944 strides, [](Value v) -> OpFoldResult { return v; });
3945 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3946}
3947
3948// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3949// to 0, strides set to 1 and inferred result type.
3950void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3951 Value dest, ArrayRef<OpFoldResult> sizes,
3952 ArrayRef<NamedAttribute> attrs) {
3953 Attribute zeroIdxAttr = b.getIndexAttr(0);
3954 Attribute oneIdxAttr = b.getIndexAttr(1);
3955 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3956 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3957 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3958}
3959
3960LogicalResult ParallelInsertSliceOp::verify() {
3961 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3962 return this->emitError("expected InParallelOpInterface parent, got:")
3963 << *(getOperation()->getParentOp());
3964
3965 // Verify result type against inferred type.
3966 RankedTensorType expectedType;
3968 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3969 getStaticSizes(), getStaticStrides(), &expectedType);
3971 return produceSliceErrorMsg(result, *this, expectedType);
3972
3973 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3974 // to the destination tensor.
3975 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3976 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3977 getStaticStrides(), /*generateErrorMessage=*/true);
3978 if (!boundsResult.isValid)
3979 return getOperation()->emitError(boundsResult.errorMessage);
3980
3981 return success();
3982}
3983
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Reuse the generic insert_slice folders, instantiated for the parallel
  // variant of the op.
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
3990
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  // Delegate to the file-local helper that compares the source shape against
  // the mixed slice sizes to find rank-reduced dimensions.
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
3994
// ParallelCombiningOpInterface implementation.
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  // The destination operand is the value updated by this parallel insertion.
  return getDestMutable();
}
3999
4000Operation *ParallelInsertSliceOp::getIteratingParent() {
4001 // Return the parent InParallelOpInterface's parent.
4002 if (auto combiningOp =
4003 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
4004 return combiningOp->getParentOp();
4005 return nullptr;
4006}
4007
4008//===----------------------------------------------------------------------===//
4009// ScatterOp
4010//===----------------------------------------------------------------------===//
4011
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest "scatter" as the SSA name prefix for the result in printed IR.
  setNameFn(getResult(), "scatter");
}
4016
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  // Validate the scatter_dims attribute against the indices shape and the
  // destination rank.
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");
  // TODO: we could also check statically that there are fewer leading index
  // tensor dims than the dest dims. If this is not the case, the unique
  // attribute cannot be true.

  // Use the GatherOp::inferResultType on the `dest` type and verify the
  // expected type matches the source type. Both the full and the
  // rank-reduced variant are acceptable.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type "
                       "mismatch: "
                       "expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
4049
4050//===----------------------------------------------------------------------===//
4051// SplatOp
4052//===----------------------------------------------------------------------===//
4053
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  // Reorder the arguments into the generated builder's signature.
  build(builder, result, aggregateType, element, dynamicSizes);
}
4058
4059void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4060 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4061 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4062 build(builder, result, aggregateType, element, dynamicSizes);
4063}
4064
4065void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4066 ArrayRef<OpFoldResult> sizes) {
4067 SmallVector<int64_t> staticShape;
4068 SmallVector<Value> dynamicSizes;
4069 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4070 build(builder, result, element, staticShape, dynamicSizes);
4071}
4072
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest "splat" as the SSA name prefix for the result in printed IR.
  setNameFn(getResult(), "splat");
}
4077
LogicalResult SplatOp::verify() {
  // Check that the number of dynamic-size operands matches the number of
  // dynamic dimensions in the result type.
  return verifyDynamicDimensionCount(getOperation(), getType(),
                                     getDynamicSizes());
}
4082
4083LogicalResult
4084SplatOp::reifyResultShapes(OpBuilder &builder,
4085 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4086 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4087 unsigned ctr = 0;
4088 for (int64_t i = 0; i < getType().getRank(); ++i) {
4089 if (getType().isDynamicDim(i)) {
4090 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4091 } else {
4092 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4093 }
4094 }
4095 return success();
4096}
4097
4098OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4099 auto constOperand = adaptor.getInput();
4100 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4101 return {};
4102
4103 // Do not fold if the splat is not statically shaped
4104 if (!getType().hasStaticShape())
4105 return {};
4106
4107 // SplatElementsAttr::get treats single value for second arg as being a
4108 // splat.
4109 return SplatElementsAttr::get(getType(), {constOperand});
4110}
4111
4112//===----------------------------------------------------------------------===//
4113// Common Canonicalizers and Folders.
4114//===----------------------------------------------------------------------===//
4115static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4116 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4117 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4118 // might need special handling of attached regions.
4119 if (isa<InsertSliceOp>(op.getOperation()) ||
4120 isa<LoopLikeOpInterface>(op.getOperation()))
4121 return false;
4122
4124}
4125
4126/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4127/// the `tensor.cast` has source that is more static than the consuming op.
4128///
4129/// Example:
4130/// ```mlir
4131/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4132/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4133/// ```
4134///
4135/// folds into:
4136///
4137/// ```mlir
4138/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4139/// ```
4140/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4141/// can add the pattern to their canonicalizers.
4143 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4145 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4146
4147 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4148 PatternRewriter &rewriter) const override {
4149
4150 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4151 // for that instead.
4152 if (!foldTensorCastPrecondition(op) ||
4153 isa<linalg::RelayoutOpInterface>(*op))
4154 return failure();
4155
4156 SmallVector<Type> newResultTypes(op->getResultTypes());
4157 SmallVector<Value> newOperands =
4158 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4159
4160 // Clone op
4161 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4162
4163 SmallVector<Value, 4> replacements;
4164 replacements.reserve(newOp->getNumResults());
4165 for (auto [oldResult, newResult] :
4166 llvm::zip(op->getResults(), newOp->getResults())) {
4167 if (newResult.getType() != oldResult.getType()) {
4168 replacements.push_back(tensor::CastOp::create(
4169 rewriter, op->getLoc(), oldResult.getType(), newResult));
4170 } else {
4171 replacements.push_back(newResult);
4172 }
4173 }
4174 rewriter.replaceOp(op, replacements);
4175
4176 return success();
4177 }
4178};
4179
4180//===----------------------------------------------------------------------===//
4181// TensorDialect
4182//===----------------------------------------------------------------------===//
4183
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  // Dialect-wide canonicalization: fold tensor.cast producers into consuming
  // destination-style ops.
  results.add<FoldTensorCastProducerOp>(getContext());
}
4188
4189//===----------------------------------------------------------------------===//
4190// TableGen'd op method definitions
4191//===----------------------------------------------------------------------===//
4192
4193#define GET_OP_CLASSES
4194#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:382
MLIRContext * getContext() const
Definition Builders.h:56
auto value_begin() const
Get an iterator of the given type to the start of the held element values.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:412
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:444
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:59
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:77
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:93
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.