MLIR 23.0.0git
TensorOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Builders.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/Matchers.h"
33#include "mlir/Support/LLVM.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/Repeated.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/SmallBitVector.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/StringRef.h"
40#include "llvm/Support/Casting.h"
41#include "llvm/Support/MathExtras.h"
42#include <optional>
43
44using namespace mlir;
45using namespace mlir::tensor;
46
47/// Materialize a single constant operation from a given attribute value with
48/// the desired resultant type.
49Operation *TensorDialect::materializeConstant(OpBuilder &builder,
50 Attribute value, Type type,
51 Location loc) {
52 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
53 return op;
54 if (complex::ConstantOp::isBuildableWith(value, type))
55 return complex::ConstantOp::create(builder, loc, type,
56 llvm::cast<ArrayAttr>(value));
57 return nullptr;
58}
59
/// NOTE(review): the extraction dropped this function's signature line
/// (original line 60); presumably this is the tail of
/// `tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
/// int64_t dim)` -- confirm against the original source before editing.
/// Returns the size of dimension `dim`: an IndexAttr when static, otherwise a
/// materialized tensor.dim op.
61 int64_t dim) {
62 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
// Dynamic extent: materialize (or fold) a tensor.dim query.
63 if (tensorType.isDynamicDim(dim))
64 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
65
// Static extent: return it as an attribute.
66 return builder.getIndexAttr(tensorType.getDimSize(dim));
67}
68
/// NOTE(review): the signature line and the local result-vector declaration
/// were dropped by the extraction (numbering jumps 71 -> 73); presumably this
/// is `tensor::getMixedSizes(OpBuilder &, Location, Value)` collecting one
/// OpFoldResult per dimension via getMixedSize -- confirm against the
/// original source.
70 Location loc, Value value) {
71 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
73 for (int64_t i = 0; i < tensorType.getRank(); ++i)
74 result.push_back(getMixedSize(builder, loc, value, i));
75 return result;
76}
77
/// NOTE(review): the signature line (presumably
/// `FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location
/// loc, OpResult opResult)`), an insertion-guard line (original 90), and the
/// `mixedSizes` declaration (original 94) were dropped by the extraction --
/// confirm against the original source before editing.
/// Returns a destination tensor for `opResult`: the tied DPS init when the
/// defining op is destination-style, otherwise a fresh tensor.empty of the
/// same shape.
79 OpResult opResult) {
80 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
81 assert(tensorType && "expected tensor type");
82
83 // If the op has a destination, it implements DestinationStyleOpInterface and
84 // we can query the destination operand from that interface.
85 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
86 if (destOp)
87 return destOp.getTiedOpOperand(opResult)->get();
88
89 // Otherwise, create a new destination tensor with the same shape.
91 b.setInsertionPoint(opResult.getDefiningOp());
92
93 // Compute sizes.
95 if (!tensorType.hasStaticShape()) {
96 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
97 ReifiedRankedShapedTypeDims reifiedShapes;
98 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
99 return failure();
100 mixedSizes = reifiedShapes[opResult.getResultNumber()];
101 } else {
102 // Static shape: Take static sizes directly.
103 for (int64_t sz : tensorType.getShape())
104 mixedSizes.push_back(b.getIndexAttr(sz));
105 }
106
107 // Create empty tensor.
108 Value emptyTensor =
109 tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
110 return emptyTensor;
111}
112
/// NOTE(review): the surrounding signature lines were dropped by the
/// extraction; presumably `LogicalResult tensor::getOrCreateDestinations(
/// OpBuilder &b, Location loc, Operation *op, SmallVector<Value> &result)` --
/// confirm against the original source.
/// Collects one destination value per tensor result of `op`; fails if any
/// per-result destination cannot be created.
114 Operation *op,
116 for (OpResult opResult : op->getResults()) {
// Only tensor-typed results get destinations; other results are skipped.
117 if (llvm::isa<TensorType>(opResult.getType())) {
118 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
119 if (failed(destination))
120 return failure();
121 result.push_back(*destination);
122 }
123 }
124 return success();
125}
126
/// NOTE(review): the signature line was dropped by the extraction; presumably
/// `bool mlir::tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2)` --
/// confirm against the original source.
/// Ranked tensors compare equal on shape + element type only (encoding is
/// deliberately ignored); all other types fall back to exact type equality.
128 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
129 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
130 return rtp1.getShape() == rtp2.getShape() &&
131 rtp1.getElementType() == rtp2.getElementType();
// A ranked tensor never equals a non-ranked-tensor type here.
132 return false;
133 }
134 return tp1 == tp2; // default implementation
135}
136
137/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
138/// rank-extending tensor.insert_slice op.
139static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
140 ArrayRef<OpFoldResult> mixedSizes) {
141 llvm::SmallBitVector droppedDims(mixedSizes.size());
142 int64_t shapePos = reducedShape.size() - 1;
143
144 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
145 size_t idx = mixedSizes.size() - size.index() - 1;
146 // Rank-reduced dims must have a static unit dimension.
147 bool isStaticUnitSize =
148 isa<Attribute>(size.value()) &&
149 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
150
151 if (shapePos < 0) {
152 // There are no more dims in the reduced shape. All remaining sizes must
153 // be rank-reduced dims.
154 assert(isStaticUnitSize && "expected unit dim");
155 droppedDims.set(idx);
156 continue;
157 }
158
159 // Dim is preserved if the size is not a static 1.
160 if (!isStaticUnitSize) {
161 --shapePos;
162 continue;
163 }
164
165 // Dim is preserved if the reduced shape dim is also 1.
166 if (reducedShape[shapePos] == 1) {
167 --shapePos;
168 continue;
169 }
170
171 // Otherwise: Dim is dropped.
172 droppedDims.set(idx);
173 }
174
175 assert(shapePos < 0 && "dimension mismatch");
176 return droppedDims;
177}
178
179/// Given a ranked tensor type and a range of values that defines its dynamic
180/// dimension sizes, turn all dynamic sizes that have a constant value into
181/// static dimension sizes.
182static RankedTensorType
183foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
184 SmallVector<Value> &foldedDynamicSizes) {
185 SmallVector<int64_t> staticShape(type.getShape());
186 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
187 "incorrect number of dynamic sizes");
188
189 // Compute new static and dynamic sizes.
190 unsigned ctr = 0;
191 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
192 if (type.isDynamicDim(i)) {
193 Value dynamicSize = dynamicSizes[ctr++];
194 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
195 if (cst.has_value()) {
196 // Dynamic size must be non-negative.
197 if (cst.value() < 0) {
198 foldedDynamicSizes.push_back(dynamicSize);
199 continue;
200 }
201 staticShape[i] = *cst;
202 } else {
203 foldedDynamicSizes.push_back(dynamicSize);
204 }
205 }
206 }
207
208 return RankedTensorType::get(staticShape, type.getElementType(),
209 type.getEncoding());
210}
211
212//===----------------------------------------------------------------------===//
213// BitcastOp
214//===----------------------------------------------------------------------===//
215
216bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
217 if (inputs.size() != 1 || outputs.size() != 1)
218 return false;
219 Type a = inputs.front(), b = outputs.front();
220 auto aT = dyn_cast<TensorType>(a);
221 auto bT = dyn_cast<TensorType>(b);
222 if (!aT || !bT)
223 return false;
224
225 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
226 return false;
227
228 return succeeded(verifyCompatibleShape(aT, bT));
229}
230
231namespace {
232
233/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
234/// operation.
235struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
236 using OpRewritePattern<BitcastOp>::OpRewritePattern;
237
238 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
239 PatternRewriter &rewriter) const final {
240 auto tensorBitcastOperand =
241 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
242 if (!tensorBitcastOperand)
243 return failure();
244
245 auto resultType = cast<TensorType>(tensorBitcast.getType());
246 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
247 tensorBitcastOperand.getOperand());
248 return success();
249 }
250};
251
252} // namespace
253
/// Registers tensor.bitcast canonicalizations (currently only the pattern
/// that collapses chained bitcasts into a single op).
254void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
255 MLIRContext *context) {
256 results.add<ChainedTensorBitcast>(context);
257}
258
259//===----------------------------------------------------------------------===//
260// CastOp
261//===----------------------------------------------------------------------===//
262
/// Suggests "cast" as the printed SSA-name prefix for tensor.cast results.
263void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
264 setNameFn(getResult(), "cast");
265}
266
267/// Returns true if `target` is a ranked tensor type that preserves static
268/// information available in the `source` ranked tensor type.
/// NOTE(review): the opening signature line (presumably
/// `bool mlir::tensor::preservesStaticInformation(Type source, Type target)`)
/// was dropped by the extraction -- confirm against the original source.
270 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
271 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
272
273 // Requires RankedTensorType.
274 if (!sourceType || !targetType)
275 return false;
276
277 // Requires same elemental type.
278 if (sourceType.getElementType() != targetType.getElementType())
279 return false;
280
281 // Requires same rank.
282 if (sourceType.getRank() != targetType.getRank())
283 return false;
284
285 // Requires same encoding.
286 if (sourceType.getEncoding() != targetType.getEncoding())
287 return false;
288
289 // If cast is towards more static sizes along any dimension, don't fold.
290 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
291 if (ShapedType::isStatic(std::get<0>(t)) &&
292 ShapedType::isDynamic(std::get<1>(t)))
293 return false;
294 }
295
296 return true;
297}
298
299/// Determines whether tensor::CastOp casts to a more dynamic version of the
300/// source tensor. This is useful to fold a tensor.cast into a consuming op and
301/// implement canonicalization patterns for ops in different dialects that may
302/// consume the results of tensor.cast operations. Such foldable tensor.cast
303/// operations are typically inserted as `slice` ops and are canonicalized,
304/// to preserve the type compatibility of their uses.
305///
306/// Returns true when all conditions are met:
307/// 1. source and result are ranked tensors with same element type and rank.
308/// 2. the tensor type has more static information than the result
309///
310/// Example:
311/// ```mlir
312/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
313/// %2 = consumer %1 ... : tensor<?x?xf32> ...
314/// ```
315///
316/// folds into:
317///
318/// ```mlir
319/// %2 = consumer %0 ... : tensor<8x16xf32> ...
320/// ```
/// NOTE(review): the signature line (presumably
/// `bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp)`) was dropped by
/// the extraction -- confirm against the original source.
322 if (!castOp)
323 return false;
324
325 // Can fold if the source of cast has at least as much static information as
326 // its results.
327 return preservesStaticInformation(castOp.getType(),
328 castOp.getSource().getType());
329}
330
331/// Determines whether the tensor::CastOp casts to a more static version of the
332/// source tensor. This is useful to fold into a producing op and implement
333/// canonicalization patterns with the `tensor.cast` op as the root, but
334/// producer being from different dialects. Returns true when all conditions are
335/// met:
336/// 1. source and result and ranked tensors with same element type and rank.
337/// 2. the result type has more static information than the source.
338///
339/// Example:
340/// ```mlir
341/// %1 = producer ... : tensor<?x?xf32>
342/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
343/// ```
344///
345/// can be canonicalized to :
346///
347/// ```mlir
348/// %2 = producer ... : tensor<8x16xf32>
349/// ```
350/// Not all ops might be canonicalizable this way, but for those that can be,
351/// this method provides a check that it is worth doing the canonicalization.
/// NOTE(review): the signature line (presumably
/// `bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp)`) was dropped by
/// the extraction -- confirm against the original source.
353 if (!castOp)
354 return false;
355 return preservesStaticInformation(castOp.getSource().getType(),
356 castOp.getType());
357}
358
/// NOTE(review): the signature line was dropped by the extraction; presumably
/// `bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op)` --
/// confirm against the original source.
/// Reports whether any operand of `op` (that is not a block argument) is
/// produced by a tensor.cast foldable into its consumer.
360 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
361 if (llvm::isa<BlockArgument>(opOperand.get()))
362 return false;
363 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
364 return castOp && canFoldIntoConsumerOp(castOp);
365 });
366}
367
/// NOTE(review): the opening signature line was dropped by the extraction;
/// presumably `SmallVector<Value>
/// mlir::tensor::getUpdatedOperandsAfterCastOpFolding(` -- confirm against
/// the original source.
/// Builds the operand list with foldable tensor.cast producers peeled off,
/// and refines `newResTy` for each tensor DPS-init whose type changed.
369 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
370 SmallVector<Value> newOperands;
371 newOperands.reserve(op->getNumOperands());
372
373 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
374
375 // Assumes that the result has dpsInits followed by nonDpsInits.
376 int64_t dpsInitIdx = 0;
377 for (OpOperand &opOperand : op->getOpOperands()) {
378 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
379 bool fold = canFoldIntoConsumerOp(tensorCastOp);
380 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
// Memref-typed inits do not contribute a result type.
381 if (op.isDpsInit(&opOperand) &&
382 !llvm::isa<MemRefType>(newOperands.back().getType()))
383 newResTy[dpsInitIdx++] = newOperands.back().getType();
384 }
385 return newOperands;
386}
387
388/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
389/// that can be folded.
/// NOTE(review): the signature line was dropped by the extraction; presumably
/// `static LogicalResult foldTensorCast(Operation *op) {` -- confirm against
/// the original source. Returns success iff at least one operand was folded.
391 bool folded = false;
392 for (OpOperand &operand : op->getOpOperands()) {
393 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
394 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
// Redirect the operand to the cast's (more static) source in place.
395 operand.set(castOp.getOperand());
396 folded = true;
397 }
398 }
399 return success(folded);
400}
401
402bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
403 if (inputs.size() != 1 || outputs.size() != 1)
404 return false;
405 Type a = inputs.front(), b = outputs.front();
406 auto aT = llvm::dyn_cast<TensorType>(a);
407 auto bT = llvm::dyn_cast<TensorType>(b);
408 if (!aT || !bT)
409 return false;
410
411 if (aT.getElementType() != bT.getElementType())
412 return false;
413
414 return succeeded(verifyCompatibleShape(aT, bT));
415}
416
417/// Compute a TensorType that has the joined shape knowledge of the two
418/// given TensorTypes. The element types need to match.
/// NOTE(review): the signature line (presumably
/// `static TensorType joinShapes(TensorType one, TensorType two) {`) and the
/// declaration of the local `join` vector (numbering jumps 430 -> 432) were
/// dropped by the extraction -- confirm against the original source.
/// Returns the null type when the two shapes are incompatible.
420 assert(one.getElementType() == two.getElementType());
421
// An unranked side contributes no knowledge; take the other side whole.
422 if (!one.hasRank())
423 return two;
424 if (!two.hasRank())
425 return one;
426
427 int64_t rank = one.getRank();
428 if (rank != two.getRank())
429 return {};
430
432 join.reserve(rank);
433 for (int64_t i = 0; i < rank; ++i) {
434 if (one.isDynamicDim(i)) {
435 join.push_back(two.getDimSize(i));
436 continue;
437 }
438 if (two.isDynamicDim(i)) {
439 join.push_back(one.getDimSize(i));
440 continue;
441 }
// Both static: they must agree, otherwise the shapes are incompatible.
442 if (one.getDimSize(i) != two.getDimSize(i))
443 return {};
444 join.push_back(one.getDimSize(i));
445 }
446 return RankedTensorType::get(join, one.getElementType());
447}
448
449namespace {
450
451/// Replaces chains of two tensor.cast operations by a single tensor.cast
452/// operation if doing so does not remove runtime constraints.
453struct ChainedTensorCast : public OpRewritePattern<CastOp> {
454 using OpRewritePattern<CastOp>::OpRewritePattern;
455
456 LogicalResult matchAndRewrite(CastOp tensorCast,
457 PatternRewriter &rewriter) const final {
458 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
459
460 if (!tensorCastOperand)
461 return failure();
462
463 auto sourceType =
464 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
465 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
466 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
467
468 // We can remove the intermediate cast if joining all three produces the
469 // same result as just joining the source and result shapes.
470 auto firstJoin =
471 joinShapes(joinShapes(sourceType, intermediateType), resultType);
472
473 // The join might not exist if the cast sequence would fail at runtime.
474 if (!firstJoin)
475 return failure();
476
477 // The newJoin always exists if the above join exists, it might just contain
478 // less information. If so, we cannot drop the intermediate cast, as doing
479 // so would remove runtime checks.
480 auto newJoin = joinShapes(sourceType, resultType);
481 if (firstJoin != newJoin)
482 return failure();
483
484 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
485 tensorCastOperand.getOperand());
486 return success();
487 }
488};
489
490/// Fold tensor.cast into tesor.extract_slice producer.
491/// Example:
492/// ```
493/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
494/// tensor<128x512xf32> to tensor<?x512xf32>
495/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
496/// ```
497/// ->
498/// ```
499/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
500/// tensor<128x512xf32> to tensor<16x512xf32>
501/// ```
502struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
503 using OpRewritePattern<CastOp>::OpRewritePattern;
504
505 LogicalResult matchAndRewrite(CastOp tensorCast,
506 PatternRewriter &rewriter) const final {
507 auto extractOperand =
508 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
509
510 // Cannot fold cast to unranked tensor.
511 auto rankedResultType =
512 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
513 if (!rankedResultType)
514 return failure();
515
516 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
517 rankedResultType.getShape() ==
518 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
519 .getShape())
520 return failure();
521
522 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
523 auto dimMask = computeRankReductionMask(
524 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
525 size_t dimIndex = 0;
526 for (size_t i = 0, e = sizes.size(); i < e; i++) {
527 if (dimMask && dimMask->count(i))
528 continue;
529 int64_t dim = rankedResultType.getShape()[dimIndex++];
530 if (ShapedType::isDynamic(dim))
531 continue;
532 sizes[i] = rewriter.getIndexAttr(dim);
533 }
534
535 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
536 tensorCast, rankedResultType, extractOperand.getSource(),
537 extractOperand.getMixedOffsets(), sizes,
538 extractOperand.getMixedStrides());
539 return success();
540 }
541};
542
543} // namespace
544
/// Registers tensor.cast canonicalizations: chained-cast collapsing and
/// folding a cast into an extract_slice producer.
545void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
546 MLIRContext *context) {
547 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
548}
549
550//===----------------------------------------------------------------------===//
551// ConcatOp
552//===----------------------------------------------------------------------===//
553
554RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
555 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
556 auto tensorTypes =
557 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
558 int64_t concatRank = tensorTypes[0].getRank();
559
560 // The concatenation dim must be in the range [0, rank).
561 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
562
563 SmallVector<int64_t> sizes(concatRank);
564 for (int64_t i = 0, e = concatRank; i < e; ++i) {
565 if (i == dim)
566 continue;
567 SaturatedInteger size;
568 for (auto tensorType : tensorTypes)
569 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
570 sizes[i] = size.asInteger();
571 }
572 auto concatSize = SaturatedInteger::wrap(0);
573 for (auto tensorType : tensorTypes)
574 concatSize =
575 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
576 sizes[dim] = concatSize.asInteger();
577 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
578}
579
580void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
581 ValueRange inputs) {
582 FailureOr<RankedTensorType> resultType =
583 inferResultType(dim, inputs.getTypes());
584 assert(succeeded(resultType) && "failed to infer concatenation result type");
585 build(builder, result, *resultType, dim, inputs);
586}
587
/// Verifies tensor.concat: at least one input, matching ranks and element
/// types, a valid concatenation dim, and a result shape consistent with the
/// shape inferred from the inputs.
/// NOTE(review): the declaration of the local `inputTypes` vector (original
/// line 592, presumably `SmallVector<RankedTensorType> inputTypes;`) was
/// dropped by the extraction -- confirm against the original source.
/// NOTE(review): the diagnostic below concatenates `resultType` directly with
/// "does not match..." with no separating space -- looks like a missing
/// leading space in the message; confirm against upstream before changing.
588LogicalResult ConcatOp::verify() {
589 if (getInputs().size() < 1)
590 return emitOpError("requires at least one input");
591
593 for (auto input : getInputs())
594 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
595
596 RankedTensorType resultType = getResultType();
597 int64_t resultRank = getRank();
598 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
599 return type.getRank() != resultRank;
600 }))
601 return emitOpError("rank of concatenated inputs must match result rank");
602
603 Type resultElementType = resultType.getElementType();
604 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
605 return type.getElementType() != resultElementType;
606 }))
607 return emitOpError("inputs and result element type must match");
608
609 int64_t dim = getDim();
610 if (dim >= resultRank)
611 return emitOpError("concatenation dim must be less than the tensor rank");
612
// Recompute the inferred shape: non-concat dims must agree across inputs.
613 SmallVector<int64_t> sizes(resultRank);
614 for (int64_t i = 0, e = resultRank; i < e; ++i) {
615 if (i == dim)
616 continue;
617 SaturatedInteger size;
618 for (auto tensorType : inputTypes) {
619 FailureOr<SaturatedInteger> maybeSize =
620 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
621 if (failed(maybeSize))
622 return emitOpError("static concatenation size mismatch along ")
623 << "non-concatenated dimension " << i;
624 size = *maybeSize;
625 }
626 sizes[i] = size.asInteger();
627 }
// The concat dim is the saturating sum of the input extents.
628 auto concatSize = SaturatedInteger::wrap(0);
629 for (auto tensorType : inputTypes)
630 concatSize =
631 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
632 sizes[dim] = concatSize.asInteger();
633 auto inferredResultType =
634 RankedTensorType::get(sizes, inputTypes[0].getElementType());
635
// Static sizes of the declared result must match the inferred ones.
636 for (auto [inferredSize, actualSize] :
637 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
638 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
639 ShapedType::isDynamic(actualSize);
640 if (!hasDynamic && inferredSize != actualSize)
641 return emitOpError("result type ")
642 << resultType << "does not match inferred shape "
643 << inferredResultType << " static sizes";
644 }
645
646 return success();
647}
648
/// Decomposes tensor.concat into tensor.empty + a chain of tensor.insert_slice
/// ops (one per input), casting back to the declared result type if needed.
/// NOTE(review): the extraction dropped several lines here (numbering jumps
/// 652->654, 678->680, 693->695): presumably the declaration of
/// `inputShapes`, the trailing `getType().getElementType());` argument of the
/// EmptyOp::create call, and the final `return` of `replacement` -- confirm
/// against the original source before editing.
649FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
650 size_t numInputs = getInputs().size();
651 uint64_t concatDim = getDim();
652
654 inputShapes.reserve(numInputs);
655 SmallVector<OpFoldResult> concatOffsets;
656 concatOffsets.reserve(numInputs);
657 SmallVector<OpFoldResult> outputShape;
658
// Accumulate per-input offsets along the concat dim and the total output
// extent via a folded affine add (s0 + s1).
659 AffineExpr addExpr =
660 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
661 OpFoldResult zero = builder.getIndexAttr(0);
662 Location loc = getLoc();
663 for (auto [index, input] : llvm::enumerate(getInputs())) {
664 SmallVector<OpFoldResult> inputShape =
665 tensor::getMixedSizes(builder, input.getLoc(), input);
666 if (index == 0) {
667 outputShape = inputShape;
668 concatOffsets.push_back(zero);
669 } else {
670 concatOffsets.push_back(outputShape[concatDim]);
671 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
672 builder, loc, addExpr,
673 {outputShape[concatDim], inputShape[concatDim]});
674 }
675 inputShapes.emplace_back(std::move(inputShape));
676 }
677
678 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
680
// Insert each input at its accumulated offset along the concat dim.
681 int64_t rank = getType().getRank();
682 OpFoldResult one = builder.getIndexAttr(1);
683 SmallVector<OpFoldResult> strides(rank, one);
684 SmallVector<OpFoldResult> offsets(rank, zero);
685 for (auto [index, input] : llvm::enumerate(getInputs())) {
686 offsets[concatDim] = concatOffsets[index];
687 auto insertSlice = tensor::InsertSliceOp::create(
688 builder, loc, input, replacement, offsets, inputShapes[index], strides);
689 replacement = insertSlice.getResult();
690 }
// The built value may be more static than the declared result type.
691 if (replacement.getType() != getType()) {
692 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
693 }
695}
696
/// Reifies the result shape of tensor.concat, preferring static sizes from
/// the declared or inferred result type and falling back to tensor.dim on the
/// inputs; the concat dim is the affine sum of the input extents.
/// NOTE(review): the extraction dropped original line 730 (presumably the
/// opening of the local `sizes` vector initializer,
/// `SmallVector<OpFoldResult> sizes = {`) -- confirm against the original
/// source before editing.
697LogicalResult
698ConcatOp::reifyResultShapes(OpBuilder &builder,
699 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
700 ValueRange inputs = getInputs();
701 int64_t dim = getDim();
702 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
703
704 Value init = inputs[0];
705 int64_t rank = getType().getRank();
706
707 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
708
709 // Pre-populate the result sizes with as much static information as possible
710 // from the given result type, as well as the inferred result type, otherwise
711 // use the dim sizes from the first input.
712 for (int64_t i = 0; i < rank; ++i) {
713 if (i == dim)
714 continue;
715 if (!getType().isDynamicDim(i)) {
716 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
717 } else if (!inferredResultType.isDynamicDim(i)) {
718 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
719 builder, getLoc(),
720 builder.getIndexAttr(inferredResultType.getDimSize(i)));
721 } else {
722 reifiedReturnShapes[0][i] =
723 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
724 }
725 }
726
727 if (getType().isDynamicDim(dim)) {
728 // Take the sum of the input sizes along the concatenated dim.
729 AffineExpr sum = builder.getAffineDimExpr(0);
731 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
732 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
733 sum = sum + builder.getAffineDimExpr(idx + 1);
734 sizes.push_back(
735 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
736 }
737 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
738 builder, getLoc(),
739 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
740 } else {
741 // If the result shape is static along the concatenated dim, use the static
742 // shape.
743 reifiedReturnShapes[0][dim] =
744 builder.getIndexAttr(getType().getDimSize(dim));
745 }
746 return success();
747}
748
/// Suggests "concat" as the printed SSA-name prefix for tensor.concat results.
749void ConcatOp::getAsmResultNames(
750 function_ref<void(Value, StringRef)> setNameFn) {
751 setNameFn(getResult(), "concat");
752}
753
754OpFoldResult ConcatOp::fold(FoldAdaptor) {
755 ValueRange inputs = getInputs();
756 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
757 return inputs[0];
758 return {};
759}
760
761namespace {
762/// Fold a concat op with a single input to a cast.
763struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
764 using OpRewritePattern<ConcatOp>::OpRewritePattern;
765
766 LogicalResult matchAndRewrite(ConcatOp concatOp,
767 PatternRewriter &rewriter) const override {
768 if (concatOp.getInputs().size() != 1)
769 return failure();
770 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
771 concatOp.getInputs()[0]);
772 return success();
773 }
774};
775
776/// Propagate static shapes into the operands of a `tensor.concat`.
777///
778/// `tensor.concat` requires every operand to match on all dimensions except the
779/// concatenation dimension. If one operand is already static in those
780/// dimensions, the other operands may safely be refined to that same static
781/// shape.
782///
783/// Example:
784///
785/// ```mlir
786/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
787/// tensor<?x12xi32>
788/// ```
789/// ->
790/// ```mlir
791/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
792/// %2 = tensor.concat dim(0) %0, %cast :
793/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
794/// ```
795struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
796 using OpRewritePattern<ConcatOp>::OpRewritePattern;
797
798 LogicalResult matchAndRewrite(ConcatOp concatOp,
799 PatternRewriter &rewriter) const override {
800 int64_t dim = concatOp.getDim();
801 RankedTensorType inferredResultType =
802 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
803
804 // Find operands for which a more static shape can be inferred.
805 LogicalResult matched = failure();
806 // Inferred operand shapes are identical in every dimension except the
807 // concatenation dimension.
808 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
809 for (auto [operandIdx, operandType] :
810 llvm::enumerate(concatOp->getOperandTypes())) {
811 // Compute inferred type for operand.
812 inferredOperandShape[dim] =
813 cast<RankedTensorType>(operandType).getDimSize(dim);
814 auto inferredOperandType = RankedTensorType::get(
815 inferredOperandShape, inferredResultType.getElementType());
816
817 // Check if inferred type is more static.
818 if (!preservesStaticInformation(inferredOperandType, operandType)) {
819 matched = success();
820
821 // Use refined operand type and create cast from original operand.
822 auto castOp =
823 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
824 concatOp.getOperand(operandIdx));
825 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
826 concatOp->setOperand(operandIdx, castOp->getResult(0));
827 });
828 }
829 }
830
831 return matched;
832 }
833};
834
835// Ensure `tensor.concat`'s result type is at least as static as can be inferred
836// from its operand types.
837///
838/// Example:
839/// ```mlir
840/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
841/// tensor<?x?xi32>
842/// ```
843/// ->
844/// ```mlir
845/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
846/// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
847/// tensor<?x?xi32>
848/// ```
849struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
850 using OpRewritePattern<ConcatOp>::OpRewritePattern;
851
852 LogicalResult matchAndRewrite(ConcatOp concatOp,
853 PatternRewriter &rewriter) const override {
854 int64_t dim = concatOp.getDim();
855 RankedTensorType inferredResultType =
856 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
857
858 // The result type should be at least as static as inferred result type.
859 if (preservesStaticInformation(inferredResultType,
860 concatOp.getResultType())) {
861 return failure();
862 }
863
864 auto newConcatOp =
865 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
866 concatOp->getOperands());
867 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
868 newConcatOp);
869
870 return success();
871 }
872};
873} // namespace
874
/// Registers tensor.concat canonicalizations: single-input folding and
/// operand/result static-shape refinement.
875void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
876 MLIRContext *context) {
877 results
878 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
879 context);
880}
881
882//===----------------------------------------------------------------------===//
883// DimOp
884//===----------------------------------------------------------------------===//
885
/// Suggests "dim" as the printed SSA-name prefix for tensor.dim results.
886void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
887 setNameFn(getResult(), "dim");
888}
889
890void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
891 int64_t index) {
892 auto loc = result.location;
893 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
894 build(builder, result, source, indexValue);
895}
896
/// NOTE(review): the body line (original 898, presumably
/// `return getConstantIntValue(getIndex());`) was dropped by the extraction
/// -- confirm against the original source before editing.
897std::optional<int64_t> DimOp::getConstantIndex() {
899}
900
/// NOTE(review): the return statements of this function (original lines 904,
/// 908, 911, 913) were dropped by the extraction; presumably each guard
/// returns `Speculation::NotSpeculatable` and the fall-through returns
/// `Speculation::Speculatable` -- confirm against the original source.
/// Guards: a dim is only speculatable when the index is a known constant, the
/// source is ranked, and the index is within rank.
901Speculation::Speculatability DimOp::getSpeculatability() {
902 auto constantIndex = getConstantIndex();
903 if (!constantIndex)
905
906 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
907 if (!rankedSourceType)
909
910 if (rankedSourceType.getRank() <= constantIndex)
912
914}
915
/// Propagates integer-range information: the result range of tensor.dim is
/// derived from the index operand's range via the shared shaped-dim helper.
916void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
917 SetIntLatticeFn setResultRange) {
// argRanges[1] is the range of the `index` operand.
918 setResultRange(getResult(),
919 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
920}
921
/// Folds tensor.dim: to a constant for static extents; to the matching
/// dynamic-extent operand of a tensor.generate or tensor.extract_slice
/// producer; or through a foldable tensor.cast on the source.
922OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
923 // All forms of folding require a known index.
924 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
925 if (!index)
926 return {};
927
928 // Folding for unranked types (UnrankedTensorType) is not supported.
929 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
930 if (!tensorType)
931 return {};
932
933 // Out of bound indices produce undefined behavior but are still valid IR.
934 // Don't choke on them.
935 int64_t indexVal = index.getInt();
936 if (indexVal < 0 || indexVal >= tensorType.getRank())
937 return {};
938
939 // Fold if the shape extent along the given index is known.
940 if (!tensorType.isDynamicDim(index.getInt())) {
941 Builder builder(getContext());
942 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
943 }
944
945 Operation *definingOp = getSource().getDefiningOp();
946
947 // Fold dim to the operand of tensor.generate.
948 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
949 auto resultType =
950 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
951 // The case where the type encodes the size of the dimension is handled
952 // above.
953 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
954
955 // Find the operand of the fromElements that corresponds to this index.
// Dynamic extents are listed in order; count dynamic dims before `index`.
956 auto dynExtents = fromElements.getDynamicExtents().begin();
957 for (auto dim : resultType.getShape().take_front(index.getInt()))
958 if (ShapedType::isDynamic(dim))
959 dynExtents++;
960
961 return Value{*dynExtents};
962 }
963
964 // The size at the given index is now known to be a dynamic size.
965 unsigned unsignedIndex = index.getValue().getZExtValue();
966
967 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
968 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
969 // `resolve-shaped-type-result-dims` pass.
970 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
971 sliceOp.isDynamicSize(unsignedIndex)) {
972 return {sliceOp.getDynamicSize(unsignedIndex)};
973 }
974 }
975
976 // dim(cast) -> dim
977 if (succeeded(foldTensorCast(*this)))
978 return getResult();
979
980 return {};
981}
982
983namespace {
984/// Fold dim of a cast into the dim of the source of the tensor cast.
985struct DimOfCastOp : public OpRewritePattern<DimOp> {
986 using OpRewritePattern<DimOp>::OpRewritePattern;
987
988 LogicalResult matchAndRewrite(DimOp dimOp,
989 PatternRewriter &rewriter) const override {
990 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
991 if (!castOp)
992 return failure();
993 Value newSource = castOp.getOperand();
994 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
995 return success();
996 }
997};
998
/// Fold dim of a destination passing style op into the dim of the corresponding
/// init.
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    // DPS ops tie each result to an init operand of the same shape, so the
    // dim of a result can be answered on the matching init instead.
    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    // Only the dim's source operand changes; rewrite in place so the rest of
    // the op (and its users) is untouched.
    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};
1019
1020/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
1021/// operand.
1022struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1023 using OpRewritePattern<DimOp>::OpRewritePattern;
1024
1025 LogicalResult matchAndRewrite(DimOp dim,
1026 PatternRewriter &rewriter) const override {
1027 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1028
1029 if (!reshape)
1030 return failure();
1031
1032 // Since tensors are immutable we don't need to worry about where to place
1033 // the extract call
1034 rewriter.setInsertionPointAfter(dim);
1035 Location loc = dim.getLoc();
1036 Value extract =
1037 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1038 if (extract.getType() != dim.getType())
1039 extract =
1040 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1041 rewriter.replaceOp(dim, extract);
1042 return success();
1043 }
1044};
1045} // namespace
1046
/// Register DimOp canonicalizations: fold dim of casts, of DPS op results,
/// and of tensor.reshape results.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
1051
1052//===----------------------------------------------------------------------===//
1053// EmptyOp
1054//===----------------------------------------------------------------------===//
1055
/// Build a tensor.empty with a fully static shape (no dynamic size operands).
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}
1063
/// Build a tensor.empty from a possibly-dynamic static shape plus the SSA
/// values supplying the dynamic extents.
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}
1070
1071void EmptyOp::build(OpBuilder &builder, OperationState &result,
1072 ArrayRef<OpFoldResult> sizes, Type elementType,
1073 Attribute encoding) {
1074 SmallVector<int64_t> staticShape;
1075 SmallVector<Value> dynamicSizes;
1076 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1077 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1078}
1079
/// Verify that the number of dynamic size operands matches the number of
/// dynamic dimensions in the result type.
LogicalResult EmptyOp::verify() {
  return verifyDynamicDimensionCount(getOperation(), getType(),
                                     getDynamicSizes());
}
1084
1085LogicalResult
1086EmptyOp::reifyResultShapes(OpBuilder &builder,
1087 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1088 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1089 unsigned ctr = 0;
1090 for (int64_t i = 0; i < getType().getRank(); ++i) {
1091 if (getType().isDynamicDim(i)) {
1092 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1093 } else {
1094 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1095 }
1096 }
1097 return success();
1098}
1099
1100Value EmptyOp::getDynamicSize(unsigned idx) {
1101 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1102 unsigned ctr = 0;
1103 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1104 if (getType().isDynamicDim(i))
1105 ++ctr;
1106 return getDynamicSizes()[ctr];
1107}
1108
1109SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1110 SmallVector<OpFoldResult> result;
1111 unsigned ctr = 0;
1112 Builder b(getContext());
1113 for (int64_t dim : getType().getShape()) {
1114 if (ShapedType::isDynamic(dim)) {
1115 result.push_back(getDynamicSizes()[ctr++]);
1116 } else {
1117 result.push_back(b.getIndexAttr(dim));
1118 }
1119 }
1120 return result;
1121}
1122
1123namespace {
1124/// Change the type of the result of a `tensor.empty` by making the result
1125/// type statically sized along dimensions that in the original operation were
1126/// defined as dynamic, but the size was defined using a `constant` op. For
1127/// example
1128///
1129/// %c5 = arith.constant 5: index
1130/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1131///
1132/// to
1133///
1134/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1135struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1136 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1137
1138 LogicalResult matchAndRewrite(EmptyOp op,
1139 PatternRewriter &rewriter) const override {
1140 SmallVector<Value> foldedDynamicSizes;
1141 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1142 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1143
1144 // Stop here if no dynamic size was promoted to static.
1145 if (foldedTensorType == op.getType())
1146 return failure();
1147
1148 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1149 foldedDynamicSizes);
1150 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1151 return success();
1152 }
1153};
1154
1155struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1156 using OpRewritePattern<DimOp>::OpRewritePattern;
1157
1158 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1159 PatternRewriter &rewriter) const override {
1160 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1161 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1162 if (!emptyTensorOp || !maybeConstantIndex)
1163 return failure();
1164 auto emptyTensorType = emptyTensorOp.getType();
1165 if (*maybeConstantIndex < 0 ||
1166 *maybeConstantIndex >= emptyTensorType.getRank() ||
1167 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1168 return failure();
1169 rewriter.replaceOp(dimOp,
1170 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1171 return success();
1172 }
1173};
1174
/// Canonicalize
///
/// ```mlir
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    // Only casts whose target type can be folded into a producer qualify.
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    // Merge the empty op's sizes with the cast's shape dimension by dimension.
    for (auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducer`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2 : The tensor cast shape is static, but empty tensor result
      // shape is dynamic.
      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3 : The tensor cast shape is dynamic and empty tensor result
      // shape is dynamic. Use the dynamic value from the empty tensor op.
      newMixedSizes.push_back(currDim);
    }

    // All dimensions agree; rebuild the empty op directly at the cast's type.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType(),
                                         resultType.getEncoding());
    return success();
  }
};
1244
1245} // namespace
1246
/// Register EmptyOp canonicalizations: fold surrounding casts and dim
/// queries, and promote constant dynamic sizes to static dimensions.
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}
1252
1253//===----------------------------------------------------------------------===//
1254// ExtractOp
1255//===----------------------------------------------------------------------===//
1256
1257namespace {
1258
1259/// Canonicalizes the pattern of the form
1260///
1261/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
1262/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1263///
1264/// to
1265///
1266/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1267struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1268 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1269
1270 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1271 PatternRewriter &rewriter) const final {
1272 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1273 if (!tensorCast)
1274 return failure();
1275 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1276 return failure();
1277 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1278 extract, tensorCast.getSource(), extract.getIndices());
1279 return success();
1280 }
1281};
1282
/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
///        tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] :
///        tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    // Delinearization below needs a static basis, so require static source
    // extents.
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    // Each collapsed-dim index expands into one source index per dimension
    // in its reassociation group.
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      // A singleton group maps its index through unchanged.
      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      // Delinearize the collapsed index over the group's source extents.
      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    // Collapse to rank 0: every source index is the constant 0.
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};
1344
1345} // namespace
1346
/// Suggest "extracted" as the prefix for the SSA value name of the result.
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}
1351
1352LogicalResult ExtractOp::verify() {
1353 // Verify the # indices match if we have a ranked type.
1354 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1355 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1356 return emitOpError("incorrect number of indices for extract_element");
1357 return success();
1358}
1359
1360/// If we have an ExtractOp consuming an InsertOp with the same
1361/// indices, we can return the InsertOp's scalar directly.
1362// TODO: This only checks the immediate producer; extend to go up the
1363// insert/extract chain if the slices are disjoint.
1364static Value foldExtractAfterInsert(ExtractOp extractOp) {
1365 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1366
1367 auto isSame = [](Value a, Value b) {
1369 };
1370 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1371 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1372 return insertOp.getScalar();
1373
1374 return {};
1375}
1376
1377OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1378 if (Attribute tensor = adaptor.getTensor()) {
1379 // If this is a splat elements attribute, simply return the value.
1380 // All of the elements of a splat attribute are the same.
1381 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1382 return splatTensor.getSplatValue<Attribute>();
1383
1384 // If this is a dense resource elements attribute, return.
1385 if (isa<DenseResourceElementsAttr>(tensor))
1386 return {};
1387 }
1388
1389 // Collect the constant indices into the tensor.
1390 SmallVector<uint64_t, 8> indices;
1391 for (Attribute indice : adaptor.getIndices()) {
1392 if (!indice || !llvm::isa<IntegerAttr>(indice))
1393 return {};
1394 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1395 }
1396
1397 // Fold extract(from_elements(...)).
1398 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1399 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1400 auto rank = tensorType.getRank();
1401 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1402 "rank mismatch");
1403 int flatIndex = 0;
1404 int stride = 1;
1405 for (int i = rank - 1; i >= 0; --i) {
1406 flatIndex += indices[i] * stride;
1407 stride *= tensorType.getDimSize(i);
1408 }
1409 // Prevent out of bounds accesses. This can happen in invalid code that
1410 // will never execute.
1411 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1412 flatIndex < 0)
1413 return {};
1414 return fromElementsOp.getElements()[flatIndex];
1415 }
1416
1417 // If this is an elements attribute, query the value at the given indices.
1418 if (Attribute tensor = adaptor.getTensor()) {
1419 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1420 if (elementsAttr && elementsAttr.isValidIndex(indices))
1421 return elementsAttr.getValues<Attribute>()[indices];
1422 }
1423
1424 if (Value result = foldExtractAfterInsert(*this))
1425 return result;
1426
1427 return {};
1428}
1429
/// Register ExtractOp canonicalizations. Note: the collapse_shape folding
/// pattern is deliberately exposed separately (opt-in) rather than here.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}
1434
1436 RewritePatternSet &patterns) {
1437 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1438}
1439
1440//===----------------------------------------------------------------------===//
1441// FromElementsOp
1442//===----------------------------------------------------------------------===//
1443
/// Suggest "from_elements" as the prefix for the SSA value name of the result.
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}
1448
1449void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1450 ValueRange elements) {
1451 assert(!elements.empty() && "expected at least one element");
1452 Type resultType = RankedTensorType::get(
1453 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1454 build(builder, result, resultType, elements);
1455}
1456
1457OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1458 // DenseElementsAttr::get requires StringAttr for element types that are not
1459 // integer, index, float, or complex (e.g. vector types), but folded constants
1460 // won't be StringAttr instances. Only fold for element types directly
1461 // supported by DenseElementsAttr.
1462 Type eltType = getType().getElementType();
1463 if (!eltType.isIntOrIndexOrFloat() && !isa<ComplexType>(eltType))
1464 return {};
1465 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1466 return DenseElementsAttr::get(getType(), adaptor.getElements());
1467 return {};
1468}
1469
1470namespace {
1471
1472// Pushes the index_casts that occur before extractions to after the extract.
1473// This minimizes type conversion in some cases and enables the extract
1474// canonicalizer. This changes:
1475//
1476// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1477// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1478//
1479// to the following:
1480//
1481// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1482// %cast = arith.index_cast %extract : i32 to index
1483//
1484// to just %element.
1485//
1486// Consider expanding this to a template and handle all tensor cast
1487// operations.
1488struct ExtractElementFromIndexCast
1489 : public OpRewritePattern<tensor::ExtractOp> {
1490 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1491
1492 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1493 PatternRewriter &rewriter) const final {
1494 Location loc = extract.getLoc();
1495 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1496 if (!indexCast)
1497 return failure();
1498
1499 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1500
1501 auto newExtract = tensor::ExtractOp::create(
1502 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1503
1504 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1505 newExtract);
1506
1507 return success();
1508 }
1509};
1510
1511} // namespace
1512
/// Register FromElementsOp canonicalizations.
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}
1517
1518//===----------------------------------------------------------------------===//
1519// GatherOp
1520//===----------------------------------------------------------------------===//
1521
/// Suggest "gather" as the prefix for the SSA value name of the result.
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
1526
1527/// Return the inferred result type for a gatherOp where:
1528/// - sourceType is the type of the source tensor gathered from
1529/// - indicesType is the type of the indices used to gather
1530/// - gatherDims are the dims along which the gather occurs.
1531/// Return a full rank or ranked-reduced variant of the type depending on
1532/// the value of rankReduced.
1533///
1534/// The leading dimensions of the index tensor give the result tensor its
1535/// leading dimensions.
1536/// The trailing dimensions of the result tensor are obtained from the source
1537/// tensor by setting the dimensions specified in gather_dims to `1` (if
1538/// rankedReduced is false), or skipping them (otherwise).
1539RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1540 RankedTensorType indicesType,
1541 ArrayRef<int64_t> gatherDims,
1542 bool rankReduced) {
1543 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1544 resultShape.reserve(resultShape.size() + sourceType.getRank());
1545 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1546 if (llvm::binary_search(gatherDims, idx)) {
1547 if (!rankReduced)
1548 resultShape.push_back(1);
1549 continue;
1550 }
1551 resultShape.push_back(sourceType.getDimSize(idx));
1552 }
1553 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1554}
1555
1556static LogicalResult
1559 StringRef gatherOrScatter, StringRef sourceOrDest) {
1560 if (dims.empty())
1561 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1562
1563 int64_t numGatherDims = dims.size();
1564 if (numGatherDims > rank)
1565 return op->emitOpError(gatherOrScatter)
1566 << "_dims overflow " << sourceOrDest << " rank";
1567 if (indices.empty() || indices.back() != numGatherDims)
1568 return op->emitOpError(gatherOrScatter)
1569 << "_dims length must match the size of last dimension of indices";
1570 for (int64_t val : dims) {
1571 if (val < 0)
1572 return op->emitOpError(gatherOrScatter)
1573 << "_dims value must be non-negative";
1574 if (val >= rank)
1575 return op->emitOpError(gatherOrScatter)
1576 << "_dims value must be smaller than " << sourceOrDest << " rank";
1577 }
1578 for (int64_t i = 1; i < numGatherDims; ++i) {
1579 if (dims[i - 1] >= dims[i])
1580 return op->emitOpError(gatherOrScatter)
1581 << "_dims values must be strictly increasing";
1582 }
1583 return success();
1584}
1585
/// Verify gather_dims against the source rank and indices shape, then check
/// that the result type matches either the inferred full-rank type or its
/// rank-reduced variant.
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  // Both the full-rank and the rank-reduced result types are accepted.
  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
                       "mismatch: "
                       "expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}
1610
1611OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1612 if (OpFoldResult reshapedSource = reshapeConstantSource(
1613 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1614 getResult().getType()))
1615 return reshapedSource;
1616 return {};
1617}
1618
1619//===----------------------------------------------------------------------===//
1620// InsertOp
1621//===----------------------------------------------------------------------===//
1622
/// Suggest "inserted" as the prefix for the SSA value name of the result.
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}
1627
1628LogicalResult InsertOp::verify() {
1629 // Verify the # indices match if we have a ranked type.
1630 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1631 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1632 return emitOpError("incorrect number of indices");
1633 return success();
1634}
1635
1636OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1637 Attribute scalar = adaptor.getScalar();
1638 Attribute dest = adaptor.getDest();
1639 if (scalar && dest)
1640 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1641 if (scalar == splatDest.getSplatValue<Attribute>())
1642 return dest;
1643 return {};
1644}
1645
1646//===----------------------------------------------------------------------===//
1647// GenerateOp
1648//===----------------------------------------------------------------------===//
1649
/// Suggest "generated" as the prefix for the SSA value name of the result.
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}
1654
1655LogicalResult GenerateOp::reifyResultShapes(
1656 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1657 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1658 int idx = 0;
1659 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1660 if (getType().isDynamicDim(dim)) {
1661 reifiedReturnShapes[0][dim] = getOperand(idx++);
1662 } else {
1663 reifiedReturnShapes[0][dim] =
1664 builder.getIndexAttr(getType().getDimSize(dim));
1665 }
1666 }
1667 return success();
1668}
1669
1670LogicalResult GenerateOp::verify() {
1671 // Ensure that the tensor type has as many dynamic dimensions as are
1672 // specified by the operands.
1673 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1674 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
1675 getOperands())))
1676 return failure();
1677 return success();
1678}
1679
/// Verify the generate op's body region: one index-typed block argument per
/// result dimension, and a terminating yield of the result element type.
LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}
1699
/// Build a tensor.generate op and populate its body via `bodyBuilder`, which
/// receives one index-typed block argument per dimension of `resultTy`.
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  // The guard restores the builder's insertion point after the block is made.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
1716
1717namespace {
1718
1719/// Canonicalizes tensor.generate operations with a constant
1720/// operand into the equivalent operation with the operand expressed in the
1721/// result type, instead. We also insert a type cast to make sure that the
1722/// resulting IR is still well-typed.
1723struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1724 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1725
1726 LogicalResult matchAndRewrite(GenerateOp generateOp,
1727 PatternRewriter &rewriter) const final {
1728 SmallVector<Value> foldedDynamicSizes;
1729 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1730 generateOp.getType(), generateOp.getDynamicExtents(),
1731 foldedDynamicSizes);
1732
1733 // Stop here if no dynamic size was promoted to static.
1734 if (foldedTensorType == generateOp.getType())
1735 return failure();
1736
1737 auto loc = generateOp.getLoc();
1738 auto newOp =
1739 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1740 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1741 newOp.getBody().begin());
1742 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1743 generateOp.getType(), newOp);
1744 return success();
1745 }
1746};
1747
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    // Bail if the generate op would not be trivially dead when unused
    // (i.e. it has side effects); cloning its body would duplicate them.
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    // Clone the body with the block arguments mapped to the extract indices.
    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    // The extract folds to (the clone of) the yielded value.
    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};
1780
1781} // namespace
1782
/// Register GenerateOp canonicalizations.
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
1788
1789//===----------------------------------------------------------------------===//
1790// RankOp
1791//===----------------------------------------------------------------------===//
1792
/// Suggests "rank" as the printed SSA name for tensor.rank results.
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}
1796
1797OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1798 // Constant fold rank when the rank of the operand is known.
1799 auto type = getOperand().getType();
1800 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1801 if (shapedType && shapedType.hasRank())
1802 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1803 return IntegerAttr();
1804}
1805
1806//===----------------------------------------------------------------------===//
1807// ReshapeOp
1808//===----------------------------------------------------------------------===//
1809
/// Suggests "reshape" as the printed SSA name for tensor.reshape results.
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}
1814
1815static int64_t getNumElements(ShapedType type) {
1816 int64_t numElements = 1;
1817 for (auto dim : type.getShape())
1818 numElements *= dim;
1819 return numElements;
1820}
1821
1822LogicalResult ReshapeOp::verify() {
1823 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1824 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1825
1826 if (operandType.getElementType() != resultType.getElementType())
1827 return emitOpError("element types of source and destination tensor "
1828 "types should be the same");
1829
1830 int64_t shapeSize =
1831 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1832 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1833 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1834
1835 if (resultRankedType) {
1836 if (operandRankedType && resultRankedType.hasStaticShape() &&
1837 operandRankedType.hasStaticShape()) {
1838 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1839 return emitOpError("source and destination tensor should have the "
1840 "same number of elements");
1841 }
1842 if (ShapedType::isDynamic(shapeSize))
1843 return emitOpError("cannot use shape operand with dynamic length to "
1844 "reshape to statically-ranked tensor type");
1845 if (shapeSize != resultRankedType.getRank())
1846 return emitOpError(
1847 "length of shape operand differs from the result's tensor rank");
1848 }
1849 return success();
1850}
1851
/// Folds tensor.reshape in four ways: constant source, reshape-of-reshape,
/// same-typed rank<=1 no-op, and a dynamic no-op proven via the shape
/// operand.
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  // Fold a reshape of a constant into a reshaped constant.
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
  // producer's input instead as the original tensor to reshape. This could
  // render such producer dead code.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  // The remaining folds only apply when source and result have the exact
  // same ranked type.
  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both 0D or 1D tensors and have the same type,
  // the reshape has no effect, even if the tensor is dynamically shaped.
  if (sourceTy.getRank() <= 1)
    return source;

  // For higher ranks with dynamic dims, the reshape is a no-op only if every
  // element of the shape operand provably equals the corresponding source
  // dim: either a constant matching the static size, or tensor.dim of the
  // same source at the same position.
  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      // Constant entry: must equal the static source dim size.
      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      // tensor.dim entry: must query the same source at constant index `id`.
      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      // Anything else: cannot prove the shapes match.
      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
1908
1909//===----------------------------------------------------------------------===//
1910// Reassociative reshape ops
1911//===----------------------------------------------------------------------===//
1912
/// Suggests "collapsed" as the printed SSA name for collapse_shape results.
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}
1917
/// Suggests "expanded" as the printed SSA name for expand_shape results.
void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}
1922
1923int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1924 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1925 "invalid resultDim");
1926 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1927 if (llvm::is_contained(it.value(), resultDim))
1928 return it.index();
1929 llvm_unreachable("could not find reassociation group");
1930}
1931
1932FailureOr<SmallVector<OpFoldResult>>
1933ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1934 RankedTensorType expandedType,
1935 ArrayRef<ReassociationIndices> reassociation,
1936 ArrayRef<OpFoldResult> inputShape) {
1937 std::optional<SmallVector<OpFoldResult>> outputShape =
1938 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1939 inputShape);
1940 if (!outputShape)
1941 return failure();
1942 return *outputShape;
1943}
1944
/// Returns the output shape as a mixed list, combining the static shape
/// attribute with the dynamic output_shape operands.
SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}
1948
1949void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1950 Type resultType, Value src,
1951 ArrayRef<ReassociationIndices> reassociation,
1952 ArrayRef<OpFoldResult> outputShape) {
1953 auto [staticOutputShape, dynamicOutputShape] =
1954 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1955 build(builder, result, cast<RankedTensorType>(resultType), src,
1956 getReassociationIndicesAttribute(builder, reassociation),
1957 dynamicOutputShape, staticOutputShape);
1958}
1959
1960void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1961 Type resultType, Value src,
1962 ArrayRef<ReassociationIndices> reassociation) {
1963 SmallVector<OpFoldResult> inputShape =
1964 getMixedSizes(builder, result.location, src);
1965 auto tensorResultTy = cast<RankedTensorType>(resultType);
1966 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1967 builder, result.location, tensorResultTy, reassociation, inputShape);
1968 SmallVector<OpFoldResult> outputShapeOrEmpty;
1969 if (succeeded(outputShape)) {
1970 outputShapeOrEmpty = *outputShape;
1971 }
1972 build(builder, result, tensorResultTy, src, reassociation,
1973 outputShapeOrEmpty);
1974}
1975
/// Returns the reassociation as symbol-less affine maps.
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
1979SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1981 getReassociationIndices());
1982}
1983
/// Returns the reassociation as symbol-less affine maps.
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
1987SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1989 getReassociationIndices());
1990}
1991
1992RankedTensorType CollapseShapeOp::inferCollapsedType(
1993 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1994 return inferCollapsedType(
1996 type.getContext(), reassociation)));
1997}
1998
1999/// Compute the RankedTensorType obtained by applying `reassociation` to
2000/// `type`.
2001RankedTensorType
2002CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2003 ArrayRef<AffineMap> reassociation) {
2004 auto shape = type.getShape();
2005 SmallVector<int64_t, 4> newShape;
2006 newShape.reserve(reassociation.size());
2007
2008 // Use the fact that reassociation is valid to simplify the logic: only use
2009 // each map's rank.
2010 assert(isReassociationValid(reassociation) && "invalid reassociation");
2011 unsigned currentDim = 0;
2012 for (AffineMap m : reassociation) {
2013 unsigned dim = m.getNumResults();
2014 auto band = shape.slice(currentDim, dim);
2015 int64_t size = 1;
2016 if (llvm::is_contained(band, ShapedType::kDynamic))
2017 size = ShapedType::kDynamic;
2018 else
2019 for (unsigned d = 0; d < dim; ++d)
2020 size *= shape[currentDim + d];
2021 newShape.push_back(size);
2022 currentDim += dim;
2023 }
2024
2025 return RankedTensorType::get(newShape, type.getElementType());
2026}
2027
2028void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2029 ArrayRef<ReassociationIndices> reassociation,
2030 ArrayRef<NamedAttribute> attrs) {
2031 auto srcType = llvm::cast<RankedTensorType>(src.getType());
2032 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2033 auto resultType =
2034 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2035 srcType.getEncoding());
2036 result.addAttribute(getReassociationAttrStrName(),
2037 getReassociationIndicesAttribute(b, reassociation));
2038 build(b, result, resultType, src, attrs);
2039}
2040
2041template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2042 TensorReshapeOp, ExpandShapeOp>::value>
2043static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2044 RankedTensorType expandedType,
2045 RankedTensorType collapsedType) {
2046 if (failed(
2047 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2048 return failure();
2049
2050 // Reshape must preserve the number of elements when statically known.
2051 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
2052 int64_t expandedNumElements = expandedType.getNumElements();
2053 int64_t collapsedNumElements = collapsedType.getNumElements();
2054 if (expandedNumElements != collapsedNumElements) {
2055 return op.emitOpError("number of elements must be preserved: ")
2056 << expandedNumElements << " != " << collapsedNumElements;
2057 }
2058 }
2059
2060 auto maps = op.getReassociationMaps();
2061 RankedTensorType expectedType =
2062 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2063 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2064 return op.emitOpError("expected collapsed type to be ")
2065 << expectedType << ", but got " << collapsedType;
2066 return success();
2067}
2068
/// Verifies expand_shape: static_output_shape length, correspondence between
/// dynamic entries and output_shape operands, and agreement with the result
/// type, before delegating to the shared reshape verifier.
LogicalResult ExpandShapeOp::verify() {
  RankedTensorType srcType = getSrc().getType();
  RankedTensorType resultType = getResult().getType();

  // One static output-shape entry is required per result dimension.
  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  // Every kDynamic entry must be backed by exactly one SSA value.
  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that the number of dynamic dims in output_shape matches the number
  // of dynamic dims in the result type.
  if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
                                         getOutputShape())))
    return failure();

  // Verify if provided output shapes are in agreement with output type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape))
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos])
      return emitOpError("invalid output shape provided at pos ") << pos;

  // For expand_shape the result is the expanded side, the source collapsed.
  return verifyTensorReshapeOp(*this, resultType, srcType);
}
2102
2103LogicalResult CollapseShapeOp::verify() {
2104 CollapseShapeOp op = *this;
2105 if (llvm::any_of(op.getReassociationIndices(),
2106 [](ReassociationIndices group) { return group.empty(); })) {
2107 return op.emitOpError("reassociation indices must not be empty");
2108 }
2109 RankedTensorType srcType = op.getSrc().getType();
2110 RankedTensorType resultType = op.getResult().getType();
2111
2112 return verifyTensorReshapeOp(op, srcType, resultType);
2113}
2114
2115namespace {
2116/// Reshape of a splat constant can be replaced with a constant of the result
2117/// type.
2118template <typename TensorReshapeOp>
2119struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2120 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2121 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2122 PatternRewriter &rewriter) const override {
2123 DenseElementsAttr attr;
2124 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2125 return failure();
2126 if (!attr || !attr.isSplat())
2127 return failure();
2128 // DenseElementsAttr requires a static shape; skip folding for dynamic
2129 // result types.
2130 if (!reshapeOp.getResultType().hasStaticShape())
2131 return failure();
2132 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2133 reshapeOp.getResultType(), attr.getRawData());
2134 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2135 return success();
2136 }
2137};
2138
2139// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2140template <typename TensorReshapeOp>
2141class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2142public:
2143 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2144
2145 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2146 PatternRewriter &rewriter) const override {
2147 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2148 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2149 return failure();
2150
2151 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2152 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2153 return success();
2154 }
2155};
2156
2157/// Reshape of a FromElements can be replaced with a FromElements of the
2158/// result type
2159template <typename TensorReshapeOp>
2160struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2161 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2162 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2163 PatternRewriter &rewriter) const override {
2164 auto fromElements =
2165 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2166 if (!fromElements)
2167 return failure();
2168
2169 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2170
2171 if (!shapedTy.hasStaticShape())
2172 return failure();
2173
2174 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2175 fromElements.getElements());
2176 return success();
2177 }
2178};
2179
2180// Fold CastOp into CollapseShapeOp when adding static information.
2181struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2182 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2183
2184 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2185 PatternRewriter &rewriter) const override {
2186 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2187 if (!tensor::canFoldIntoConsumerOp(castOp))
2188 return failure();
2189
2190 RankedTensorType srcType =
2191 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2192 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2193 srcType, collapseShapeOp.getReassociationMaps());
2194
2195 if (newResultType == collapseShapeOp.getResultType()) {
2196 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2197 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2198 });
2199 } else {
2200 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2201 newResultType, castOp.getSource(),
2202 collapseShapeOp.getReassociation());
2203 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2204 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2205 }
2206 return success();
2207 }
2208};
2209
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    // Start from the expand's result shape and fill in static sizes below.
    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    // NOTE: output_shape only holds values for dynamic result dims, so
    // `outputIt` is advanced exactly once per dynamic dim encountered.
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires at
        // least one of the expanded dimensions to be dynamic if the input is
        // dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        // A constant output-shape operand can be promoted to a static dim.
        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    // Rebuild the expand at the more-static types and cast back to the
    // original result type so existing users are unaffected.
    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
2289} // namespace
2290
/// Registers canonicalization patterns for tensor.expand_shape.
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}
2300
/// Registers canonicalization patterns for tensor.collapse_shape.
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}
2312
2313OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2315 adaptor.getOperands());
2316}
2317
2318OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2320 adaptor.getOperands());
2321}
2322
2323//===----------------------------------------------------------------------===//
2324// ExtractSliceOp
2325//===----------------------------------------------------------------------===//
2326
/// Suggests "extracted_slice" as the printed SSA name for extract_slice
/// results.
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}
2331
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<int64_t> staticSizes) {
  // The non-rank-reduced result keeps one dim per source dim, so the number
  // of static sizes must match the source rank exactly. Only sizes
  // participate in type inference; offsets/strides do not affect the type.
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
2347
2348// TODO: This uses neither offsets nor strides!
2349RankedTensorType
2350ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2351 ArrayRef<OpFoldResult> sizes) {
2352 SmallVector<int64_t> staticSizes;
2353 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2354
2355 assert(static_cast<int64_t>(staticSizes.size()) ==
2356 sourceTensorType.getRank() &&
2357 "unexpected staticSizes not equal to rank of source");
2358 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2359 sourceTensorType.getEncoding());
2360}
2361
2362/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2363/// number of sizes), drop as many size 1 as needed to produce an inferred
2364/// type with the desired rank.
2365///
2366/// Note that there may be multiple ways to compute this rank-reduced type:
2367/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2368///
2369/// To disambiguate, this function always drops the first 1 sizes occurrences.
2370RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2371 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2372 ArrayRef<int64_t> sizes) {
2373 // Type inferred in the absence of rank-reducing behavior.
2374 auto inferredType = llvm::cast<RankedTensorType>(
2375 inferResultType(sourceRankedTensorType, sizes));
2376 int rankDiff = inferredType.getRank() - desiredResultRank;
2377 if (rankDiff > 0) {
2378 auto shape = inferredType.getShape();
2379 llvm::SmallBitVector dimsToProject =
2380 getPositionsOfShapeOne(rankDiff, shape);
2381 SmallVector<int64_t> projectedShape;
2382 // Best effort rank-reducing: drop 1s in order.
2383 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2384 if (!dimsToProject.test(pos))
2385 projectedShape.push_back(shape[pos]);
2386 inferredType =
2387 RankedTensorType::get(projectedShape, inferredType.getElementType());
2388 }
2389 return inferredType;
2390}
2391
/// Mixed-size overload: strips the dynamic values (they become kDynamic
/// sentinels) and defers to the static-size overload above.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticSizes);
}
2401
2402/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2403/// result type. If the type passed is nullptr, it is inferred.
2404void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2405 RankedTensorType resultType, Value source,
2406 ArrayRef<OpFoldResult> offsets,
2407 ArrayRef<OpFoldResult> sizes,
2408 ArrayRef<OpFoldResult> strides,
2409 ArrayRef<NamedAttribute> attrs) {
2410 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2411 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2412 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2413 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2414 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2415 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2416 // Structuring implementation this way avoids duplication between builders.
2417 if (!resultType) {
2418 resultType = llvm::cast<RankedTensorType>(
2419 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2420 }
2421 result.addAttributes(attrs);
2422 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2423 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2424 b.getDenseI64ArrayAttr(staticSizes),
2425 b.getDenseI64ArrayAttr(staticStrides));
2426}
2427
/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  // A null result type triggers inference in the custom-type builder.
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
2437
/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
/// a Range vector.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  // Unzip the ranges into their offset/size/stride components.
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
2446
2447/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2448/// the type passed is nullptr, it is inferred.
2449void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2450 RankedTensorType resultType, Value source,
2451 ValueRange offsets, ValueRange sizes,
2452 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2453 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2454 offsets, [](Value v) -> OpFoldResult { return v; });
2455 SmallVector<OpFoldResult> sizeValues =
2456 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2457 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2458 strides, [](Value v) -> OpFoldResult { return v; });
2459 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2460}
2461
/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  // A null result type triggers inference in the custom-type builder.
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
2468
2470 Operation *op,
2471 RankedTensorType expectedType) {
2472 switch (result) {
2474 return success();
2476 return op->emitError("expected rank to be smaller or equal to ")
2477 << "the other rank. ";
2479 return op->emitError("expected type to be ")
2480 << expectedType << " or a rank-reduced version. (size mismatch) ";
2482 return op->emitError("expected element type to be ")
2483 << expectedType.getElementType();
2484 default:
2485 llvm_unreachable("unexpected extract_slice op verification result");
2486 }
2487}
2488
2489/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2490/// result type, offsets set to 0 and strides set to 1.
2491void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2492 RankedTensorType resultType, Value source,
2493 ArrayRef<OpFoldResult> sizes,
2494 ArrayRef<NamedAttribute> attrs) {
2495 Attribute zeroIdxAttr = b.getIndexAttr(0);
2496 Attribute oneIdxAttr = b.getIndexAttr(1);
2497 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2498 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2499 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2500}
2501
2502/// Verifier for ExtractSliceOp.
2503LogicalResult ExtractSliceOp::verify() {
2504 RankedTensorType sourceType = getSourceType();
2505
2506 // Verify result type against inferred type.
2507 RankedTensorType expectedType =
2508 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2511 return produceSliceErrorMsg(result, *this, expectedType);
2512
2513 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2514 // to the source tensor.
2515 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2516 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2517 getStaticStrides(), /*generateErrorMessage=*/true);
2518 if (!boundsResult.isValid)
2519 return getOperation()->emitError(boundsResult.errorMessage);
2520
2521 return success();
2522}
2523
/// Returns a bit vector marking the source dims dropped by rank reduction.
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
  return ::getDroppedDims(getType().getShape(), getMixedSizes());
}
2527
2528FailureOr<Value>
2529ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2530 ArrayRef<int64_t> desiredShape) {
2531 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2532 assert(sourceTensorType && "not a ranked tensor type");
2533 auto sourceShape = sourceTensorType.getShape();
2534 if (sourceShape.equals(desiredShape))
2535 return value;
2536 auto maybeRankReductionMask =
2537 mlir::computeRankReductionMask(sourceShape, desiredShape);
2538 if (!maybeRankReductionMask)
2539 return failure();
2541 b, loc, value,
2542 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2543}
2544
2545LogicalResult ExtractSliceOp::reifyResultShapes(
2546 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2547 reifiedReturnShapes.resize(1);
2548 reifiedReturnShapes[0].reserve(getType().getRank());
2549 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2550 llvm::SmallBitVector droppedDims = getDroppedDims();
2551 for (const auto &size : enumerate(mixedSizes)) {
2552 if (droppedDims.test(size.index()))
2553 continue;
2554 reifiedReturnShapes[0].push_back(size.value());
2555 }
2556 return success();
2557}
2558
2559namespace {
2560/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2561/// This essentially pushes memref_cast past its consuming slice when
2562/// `canFoldIntoConsumerOp` is true.
2563///
2564/// Example:
2565/// ```
2566/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2567/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2568/// tensor<3x4xf32>
2569/// ```
2570/// is rewritten into:
2571/// ```
2572/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2573/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2574/// ```
2575class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2576public:
2577 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2578
2579 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2580 PatternRewriter &rewriter) const override {
2581 // Any constant operand, just return to let the constant folder kick in.
2582 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2583 return matchPattern(operand, matchConstantIndex());
2584 }))
2585 return failure();
2586
2587 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2588 if (!castOp)
2589 return failure();
2590
2591 if (!canFoldIntoConsumerOp(castOp))
2592 return failure();
2593
2594 // Pattern does not apply if the produced op would not verify.
2595 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2596 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2597 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2598 sliceOp.getStaticStrides());
2599 if (!sliceResult.isValid)
2600 return failure();
2601
2602 // Create folded extract.
2603 Location loc = sliceOp.getLoc();
2604 Value newResult = ExtractSliceOp::create(
2605 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2606 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2607 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2608 sliceOp.getStaticStrides());
2609 rewriter.replaceOp(sliceOp, newResult);
2610 return success();
2611 }
2612};
2613
2614/// Slice elements from `values` into `outValues`. `counts` represents the
2615/// numbers of elements to stride in the original values for each dimension.
2616/// The output values can be used to construct a DenseElementsAttr.
2617template <typename IterTy, typename ElemTy>
2618static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2619 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2620 ArrayRef<int64_t> strides,
2621 llvm::SmallVectorImpl<ElemTy> *outValues) {
2622 assert(offsets.size() == sizes.size());
2623 assert(offsets.size() == strides.size());
2624 if (offsets.empty())
2625 return;
2626
2627 int64_t offset = offsets.front();
2628 int64_t size = sizes.front();
2629 int64_t stride = strides.front();
2630 if (offsets.size() == 1) {
2631 for (int64_t i = 0; i < size; ++i, offset += stride)
2632 outValues->push_back(*(values + offset));
2633
2634 return;
2635 }
2636
2637 for (int64_t i = 0; i < size; ++i, offset += stride) {
2638 auto begin = values + offset * counts.front();
2639 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2640 offsets.drop_front(), sizes.drop_front(),
2641 strides.drop_front(), outValues);
2642 }
2643}
2644
2645/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2646/// folded operation might introduce more constant data; Users can control
2647/// their heuristics by the control function.
2648class ConstantOpExtractSliceFolder final
2649 : public OpRewritePattern<ExtractSliceOp> {
2650public:
2651 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2652
2653 ConstantOpExtractSliceFolder(MLIRContext *context,
2655 : OpRewritePattern<ExtractSliceOp>(context),
2656 controlFn(std::move(controlFn)) {}
2657
2658 LogicalResult matchAndRewrite(ExtractSliceOp op,
2659 PatternRewriter &rewriter) const override {
2660 DenseElementsAttr attr;
2661 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2662 return failure();
2663
2664 // A constant splat is handled by fold().
2665 if (attr.isSplat())
2666 return failure();
2667
2668 // Dynamic result shape is not supported.
2669 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2670 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2671 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2672 return failure();
2673
2674 // Customized control over the folding.
2675 if (!controlFn(op))
2676 return failure();
2677
2678 int64_t count = sourceType.getNumElements();
2679 if (count == 0)
2680 return failure();
2681
2682 // Check if there are any dynamic parts, which are not supported.
2683 auto offsets = op.getStaticOffsets();
2684 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2685 return failure();
2686 auto sizes = op.getStaticSizes();
2687 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2688 return failure();
2689 auto strides = op.getStaticStrides();
2690 if (llvm::is_contained(strides, ShapedType::kDynamic))
2691 return failure();
2692
2693 // Compute the stride for each dimension.
2694 SmallVector<int64_t> counts;
2695 ArrayRef<int64_t> shape = sourceType.getShape();
2696 counts.reserve(shape.size());
2697 for (int64_t v : shape) {
2698 count = count / v;
2699 counts.push_back(count);
2700 }
2701
2702 // Slice the elements and construct a new attribute.
2703 SmallVector<Attribute> outValues;
2704 outValues.reserve(resultType.getNumElements());
2705 sliceElements(attr.value_begin<Attribute>(), counts, offsets, sizes,
2706 strides, &outValues);
2707 auto newAttr = DenseElementsAttr::get(resultType, outValues);
2708 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2709 return success();
2710 }
2711
2712private:
2713 /// This additionally controls whether the fold happens or not. Users can
2714 /// impose their heuristics in the function.
2716};
2717
2718} // namespace
2719
2721 RewritePatternSet &patterns,
2722 const ControlConstantExtractSliceFusionFn &controlFn) {
2723 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2724}
2725
2726/// Return the canonical type of the result of an extract_slice op.
2728 RankedTensorType operator()(ExtractSliceOp op,
2729 ArrayRef<OpFoldResult> mixedOffsets,
2730 ArrayRef<OpFoldResult> mixedSizes,
2731 ArrayRef<OpFoldResult> mixedStrides) {
2732 // Infer a tensor type without taking into account any rank reductions.
2733 RankedTensorType nonReducedType =
2734 ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);
2735
2736 // Directly return the non-rank reduced type if there are no dropped
2737 // dims.
2738 llvm::SmallBitVector droppedDims = op.getDroppedDims();
2739 if (droppedDims.none())
2740 return nonReducedType;
2741
2742 // Build the reduced shape, preserving the original rank reduction pattern.
2743 SmallVector<int64_t> targetShape;
2744 for (auto i : llvm::seq<int64_t>(mixedSizes.size()))
2745 if (!droppedDims.test(i))
2746 targetShape.push_back(nonReducedType.getDimSize(i));
2747
2748 return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
2749 nonReducedType.getEncoding());
2750 }
2751};
2752
2753/// A canonicalizer wrapper to replace ExtractSliceOps.
2755 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2756 ExtractSliceOp newOp) {
2757 Value replacement = newOp.getResult();
2758 if (replacement.getType() != op.getType())
2759 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2760 replacement);
2761 rewriter.replaceOp(op, replacement);
2762 }
2763};
2764
2765void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2766 MLIRContext *context) {
2767 results.add<
2768 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2769 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2770 ExtractSliceOpCastFolder>(context);
2771}
2772
2773//
2774static LogicalResult
2775foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2776 ShapedType shapedType) {
2777 OpBuilder b(op.getContext());
2778 for (OpFoldResult ofr : op.getMixedOffsets())
2779 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2780 return failure();
2781 // Rank-reducing noops only need to inspect the leading dimensions:
2782 // llvm::zip is appropriate.
2783 auto shape = shapedType.getShape();
2784 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2785 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2786 return failure();
2787 for (OpFoldResult ofr : op.getMixedStrides())
2788 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2789 return failure();
2790 return success();
2791}
2792
2793/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2794/// slice, we can return the InsertSliceOp's source directly.
2795// TODO: This only checks the immediate producer; extend to go up the
2796// insert/extract chain if the slices are disjoint.
2797static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2798 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2799
2800 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2801 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2802 insertOp.isSameAs(extractOp, isSame))
2803 return insertOp.getSource();
2804
2805 return {};
2806}
2807
2808OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2809 if (OpFoldResult reshapedSource = reshapeConstantSource(
2810 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2811 getResult().getType()))
2812 return reshapedSource;
2813 if (getSourceType() == getType() &&
2815 return this->getSource();
2816 if (Value slice = foldExtractAfterInsertSlice(*this))
2817 return slice;
2818
2819 return OpFoldResult();
2820}
2821
2823 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2824 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2825 unsigned rank = rankedTensorType.getRank();
2826 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2828 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2829 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2830 offsets, sizes, strides);
2831}
2832
//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

2837void InsertSliceOp::getAsmResultNames(
2838 function_ref<void(Value, StringRef)> setNameFn) {
2839 setNameFn(getResult(), "inserted_slice");
2840}
2841
2842// Build a InsertSliceOp with mixed static and dynamic entries.
2843void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2844 Value dest, ArrayRef<OpFoldResult> offsets,
2846 ArrayRef<OpFoldResult> strides,
2848 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2849 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2850 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2851 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2852 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2853 result.addAttributes(attrs);
2854 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2855 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2856 b.getDenseI64ArrayAttr(staticSizes),
2857 b.getDenseI64ArrayAttr(staticStrides));
2858}
2859
2860/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2861/// Range vector.
2862void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2863 Value dest, ArrayRef<Range> ranges,
2864 ArrayRef<NamedAttribute> attrs) {
2865 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2866 build(b, result, source, dest, offsets, sizes, strides, attrs);
2867}
2868
2869// Build a InsertSliceOp with dynamic entries.
2870void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2871 Value dest, ValueRange offsets, ValueRange sizes,
2872 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2873 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2874 offsets, [](Value v) -> OpFoldResult { return v; });
2875 SmallVector<OpFoldResult> sizeValues =
2876 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2877 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2878 strides, [](Value v) -> OpFoldResult { return v; });
2879 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2880}
2881
2882/// Rank-reducing type verification for both InsertSliceOp and
2883/// ParallelInsertSliceOp.
2885 RankedTensorType srcType, RankedTensorType dstType,
2886 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2887 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2888 // insert_slice is the inverse of extract_slice, use the same type
2889 // inference.
2890 RankedTensorType expected =
2891 ExtractSliceOp::inferResultType(dstType, staticSizes);
2892 if (expectedType)
2893 *expectedType = expected;
2894 return isRankReducedType(expected, srcType);
2895}
2896
2897/// Verifier for InsertSliceOp.
2898LogicalResult InsertSliceOp::verify() {
2899 // Verify result type against inferred type.
2900 RankedTensorType expectedType;
2902 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2903 getStaticSizes(), getStaticStrides(), &expectedType);
2905 return produceSliceErrorMsg(result, *this, expectedType);
2906
2907 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2908 // to the destination tensor.
2909 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2910 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2911 getStaticStrides(), /*generateErrorMessage=*/true);
2912 if (!boundsResult.isValid)
2913 return getOperation()->emitError(boundsResult.errorMessage);
2914
2915 return success();
2916}
2917
2918/// If we have two consecutive InsertSliceOp writing to the same slice, we
2919/// can mutate the second InsertSliceOp's destination to the first one's.
2920///
2921/// Example:
2922///
2923/// ```mlir
2924/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2925/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2926/// ```
2927///
2928/// folds into:
2929///
2930/// ```mlir
2931/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2932/// ```
2933///
2934/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2935static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2936 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2937
2938 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2939 if (!prevInsertOp ||
2940 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2941 !prevInsertOp.isSameAs(insertOp, isSame))
2942 return failure();
2943
2944 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2945 return success();
2946}
2947
2948/// Folds round-trip extract/insert slice op pairs.
2949/// Example:
2950/// ```mlir
2951/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2952/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2953/// ```
2954/// can be folded into %val.
2955static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2956 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2957
2958 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2959 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2960 !extractOp.isSameAs(insertOp, isSame))
2961 return nullptr;
2962
2963 return extractOp.getSource();
2964}
2965
2966OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2967 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2968 getSourceType() == getType() &&
2970 return this->getSource();
2971 if (succeeded(foldInsertAfterInsertSlice(*this)))
2972 return getResult();
2973 if (auto result = foldInsertAfterExtractSlice(*this))
2974 return result;
2975 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2976 return getDest();
2977 return OpFoldResult();
2978}
2979
2980LogicalResult InsertSliceOp::reifyResultShapes(
2981 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2982 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2983 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2984 return success();
2985}
2986
2987namespace {
2988/// Pattern to rewrite a insert_slice op with constant arguments.
2989///
2990/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2991template <typename InsertOpTy>
2992class InsertSliceOpConstantArgumentFolder final
2993 : public OpRewritePattern<InsertOpTy> {
2994public:
2995 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2996
2997 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2998 PatternRewriter &rewriter) const override {
2999 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
3000 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3001 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
3002
3003 // No constant operands were folded, just return;
3004 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
3005 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
3006 failed(foldDynamicStrideList(mixedStrides)))
3007 return failure();
3008
3009 // Pattern does not apply if the produced op would not verify.
3010 SliceBoundsVerificationResult sliceResult =
3011 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
3012 mixedOffsets, mixedSizes, mixedStrides);
3013 if (!sliceResult.isValid)
3014 return failure();
3015
3016 // Create the new op in canonical form.
3017 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
3018 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
3019 mixedSizes);
3020 Value toInsert = insertSliceOp.getSource();
3021 if (sourceType != insertSliceOp.getSourceType()) {
3022 OpBuilder::InsertionGuard g(rewriter);
3023 // The only difference between InsertSliceOp and ParallelInsertSliceOp
3024 // is that the insertion point is just before the InParallelOp in
3025 // the parallel case.
3026 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
3027 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3028 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3029 sourceType, toInsert);
3030 }
3031 rewriter.replaceOpWithNewOp<InsertOpTy>(
3032 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
3033 mixedSizes, mixedStrides);
3034 return success();
3035 }
3036};
3037
3038/// Fold tensor_casts with insert_slice operations. If the source or
3039/// destination tensor is a tensor_cast that removes static type information,
3040/// the cast is folded into the insert_slice operation. E.g.:
3041///
3042/// ```mlir
3043/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3044/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3045/// ```
3046///
3047/// folds into:
3048///
3049/// ```mlir
3050/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3051/// ```
3052///
3053/// Note: When folding a cast on the destination tensor, the result of the
3054/// insert_slice operation is casted to ensure that the type of the result did
3055/// not change.
3056///
3057/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3058template <typename InsertOpTy>
3059struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3060 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3061
3062 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3063 PatternRewriter &rewriter) const override {
3064 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3065 return matchPattern(operand, matchConstantIndex());
3066 }))
3067 return failure();
3068
3069 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3070 auto castOp = v.getDefiningOp<tensor::CastOp>();
3071 if (!castOp || !canFoldIntoConsumerOp(castOp))
3072 return std::nullopt;
3073 return castOp.getSource();
3074 };
3075 std::optional<Value> sourceCastSource =
3076 getSourceOfCastOp(insertSliceOp.getSource());
3077 std::optional<Value> destCastSource =
3078 getSourceOfCastOp(insertSliceOp.getDest());
3079 if (!sourceCastSource && !destCastSource)
3080 return failure();
3081
3082 auto src =
3083 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3084 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3085 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3086 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3087 if (!srcType || !dstType)
3088 return failure();
3089
3090 // The tensor.cast source could have additional static information not seen
3091 // in the insert slice op static sizes, so we ignore dynamic dims when
3092 // computing the rank reduction mask.
3093 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3094 auto rankReductionMask = computeRankReductionMask(
3095 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3096 if (!rankReductionMask.has_value())
3097 return failure();
3098 // Replace dimensions in the insert slice op with corresponding static dims
3099 // from the cast source type. If the insert slice sizes have static dims
3100 // that are not static in the tensor.cast source (i.e., when the cast op
3101 // casts a dynamic dim to static), the dim should not be replaced, and the
3102 // pattern will fail later in `verifyInsertSliceOp`.
3103 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3104 int64_t rankReducedIdx = 0;
3105 for (auto [idx, size] : enumerate(staticSizes)) {
3106 if (!rankReductionMask.value().contains(idx) &&
3107 !srcType.isDynamicDim(rankReducedIdx)) {
3108 mixedSizes[idx] = getAsIndexOpFoldResult(
3109 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3110 size = srcType.getDimSize(rankReducedIdx++);
3111 }
3112 }
3113
3114 // Pattern does not apply if the produced op would not verify.
3115 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3116 staticSizes, insertSliceOp.getStaticStrides()) !=
3117 SliceVerificationResult::Success)
3118 return failure();
3119 SliceBoundsVerificationResult sliceResult =
3120 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3121 mixedSizes, insertSliceOp.getMixedStrides());
3122 if (!sliceResult.isValid)
3123 return failure();
3124
3125 Operation *replacement =
3126 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3127 insertSliceOp.getMixedOffsets(), mixedSizes,
3128 insertSliceOp.getMixedStrides());
3129
3130 // In the parallel case there is no result and so nothing to cast.
3131 bool isParallelInsert =
3132 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3133 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3134 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3135 insertSliceOp.getDestType(),
3136 replacement->getResult(0));
3137 }
3138 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3139 return success();
3140 }
3141};
3142
3143/// If additional static type information can be deduced from a insert_slice's
3144/// size operands, insert an explicit cast of the op's source operand. This
3145/// enables other canonicalization patterns that are matching for tensor_cast
3146/// ops such as `ForOpTensorCastFolder` in SCF.
3147///
3148/// Example:
3149///
3150/// ```mlir
3151/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3152/// : tensor<?x?xf32> into ...
3153/// ```
3154///
3155/// folds into:
3156///
3157/// ```mlir
3158/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3159/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3160/// : tensor<64x64xf32> into ...
3161/// ```
3162///
3163/// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
3164template <typename InsertOpTy>
3165struct InsertSliceOpSourceCastInserter final
3166 : public OpRewritePattern<InsertOpTy> {
3167 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3168
3169 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3170 PatternRewriter &rewriter) const override {
3171 RankedTensorType srcType = insertSliceOp.getSourceType();
3172 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3173 return failure();
3174 SmallVector<int64_t> newSrcShape(srcType.getShape());
3175 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3176 if (std::optional<int64_t> constInt =
3177 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3178 // Bail on invalid IR.
3179 if (*constInt < 0)
3180 return failure();
3181 newSrcShape[i] = *constInt;
3182 }
3183 }
3184 if (!hasValidSizesOffsets(newSrcShape))
3185 return failure();
3186
3187 RankedTensorType newSrcType = RankedTensorType::get(
3188 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3189 if (srcType == newSrcType ||
3190 !preservesStaticInformation(srcType, newSrcType) ||
3191 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3192 return failure();
3193
3194 // newSrcType is:
3195 // 1) Different from srcType.
3196 // 2) "More static" than srcType.
3197 // 3) Cast-compatible with srcType.
3198 // Insert the cast.
3199 OpBuilder::InsertionGuard g(rewriter);
3200 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3201 // that the insertion point is just before the InParallelOp in the
3202 // parallel case.
3203 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3204 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3205 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3206 newSrcType, insertSliceOp.getSource());
3207 rewriter.replaceOpWithNewOp<InsertOpTy>(
3208 insertSliceOp, cast, insertSliceOp.getDest(),
3209 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3210 insertSliceOp.getMixedStrides());
3211 return success();
3212 }
3213};
3214} // namespace
3215
3216llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3217 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3218}
3219
3220void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3221 MLIRContext *context) {
3222 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3223 InsertSliceOpCastFolder<InsertSliceOp>,
3224 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3225}
3226
3228 Location loc,
3229 Value tensor,
3230 Value dest) {
3231 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3232 unsigned rank = rankedTensorType.getRank();
3233 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3234 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3235 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3236 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3237 sizes, strides);
3238}
3239
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

3244void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3245 setNameFn(getResult(), "padded");
3246}
3247
3248LogicalResult PadOp::verify() {
3249 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3250 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3251 auto expectedType =
3252 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3253 if (!expectedType) {
3254 return emitError("failed to infer expectedType from sourceType ")
3255 << sourceType << ", specified resultType is " << resultType;
3256 }
3257 if (resultType.getRank() != expectedType.getRank()) {
3258 return emitError("specified type ")
3259 << resultType << " does not match the inferred type "
3260 << expectedType;
3261 }
3262 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3263 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3264 continue;
3265 if (expectedType.isDynamicDim(i))
3266 continue;
3267 return emitError("specified type ")
3268 << resultType << " does not match the inferred type "
3269 << expectedType;
3270 }
3271
3272 return success();
3273}
3274
3275LogicalResult PadOp::verifyRegions() {
3276 auto &region = getRegion();
3277 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3278 Block &block = region.front();
3279 if (block.getNumArguments() != rank)
3280 return emitError("expected the block to have ") << rank << " arguments";
3281
3282 // Note: the number and type of yield values are checked in the YieldOp.
3283 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3284 if (!en.value().isIndex())
3285 return emitOpError("expected block argument ")
3286 << (en.index() + 1) << " to be an index";
3287 }
3288
3289 // Ensure that the region yields an element of the right type.
3290 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3291 if (yieldOp.getValue().getType() !=
3292 llvm::cast<ShapedType>(getType()).getElementType())
3293 return emitOpError("expected yield type to match shape element type");
3294
3295 return success();
3296}
3297
3298RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3299 ArrayRef<int64_t> staticLow,
3300 ArrayRef<int64_t> staticHigh,
3301 ArrayRef<int64_t> resultShape) {
3302 unsigned rank = sourceType.getRank();
3303 if (staticLow.size() != rank)
3304 return RankedTensorType();
3305 if (staticHigh.size() != rank)
3306 return RankedTensorType();
3307 if (!resultShape.empty() && resultShape.size() != rank)
3308 return RankedTensorType();
3309
3310 SmallVector<int64_t, 4> inferredShape;
3311 for (auto i : llvm::seq<unsigned>(0, rank)) {
3312 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3313 staticHigh[i] == ShapedType::kDynamic) {
3314 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3315 : resultShape[i]);
3316 } else {
3317 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3318 assert((resultShape.empty() || size == resultShape[i] ||
3319 resultShape[i] == ShapedType::kDynamic) &&
3320 "mismatch between inferred shape and result shape");
3321 inferredShape.push_back(size);
3322 }
3323 }
3324
3325 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3326}
3327
3328void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3329 Value source, ArrayRef<int64_t> staticLow,
3330 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3331 bool nofold, ArrayRef<NamedAttribute> attrs) {
3332 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3333 if (!resultType)
3334 resultType = inferResultType(sourceType, staticLow, staticHigh);
3335 result.addAttributes(attrs);
3336 build(b, result, resultType, source, low, high,
3337 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3338 nofold ? b.getUnitAttr() : UnitAttr());
3339}
3340
3341void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3342 Value source, ValueRange low, ValueRange high, bool nofold,
3343 ArrayRef<NamedAttribute> attrs) {
3344 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3345 unsigned rank = sourceType.getRank();
3346 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3347 build(b, result, resultType, source, staticVector, staticVector, low, high,
3348 nofold, attrs);
3349}
3350
3351void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3352 Value source, ArrayRef<OpFoldResult> low,
3353 ArrayRef<OpFoldResult> high, bool nofold,
3354 ArrayRef<NamedAttribute> attrs) {
3355 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3356 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3357 SmallVector<int64_t, 4> staticLow, staticHigh;
3358 // staticLow and staticHigh have full information of the padding config.
3359 // This will grow staticLow and staticHigh with 1 value. If the config is
3360 // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
3361 // value as well.
3362 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3363 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3364 if (!resultType) {
3365 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3366 }
3367 assert(llvm::isa<RankedTensorType>(resultType));
3368 result.addAttributes(attrs);
3369 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3370 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3371 nofold ? b.getUnitAttr() : UnitAttr());
3372}
3373
3374void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3375 Value source, ArrayRef<OpFoldResult> low,
3376 ArrayRef<OpFoldResult> high, Value constantPadValue,
3377 bool nofold, ArrayRef<NamedAttribute> attrs) {
3378 build(b, result, resultType, source, low, high, nofold, attrs);
3379
3380 // Add a region and a block to yield the pad value.
3381 Region *region = result.regions[0].get();
3382 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3383 Repeated<Type> blockArgTypes(sourceRank, b.getIndexType());
3384 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3385
3386 // `builder.createBlock` changes the insertion point within the block. Create
3387 // a guard to reset the insertion point of the builder after it is destroyed.
3388 OpBuilder::InsertionGuard guard(b);
3389 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3390 tensor::YieldOp::create(b, result.location, constantPadValue);
3391}
3392
3393llvm::SmallBitVector PadOp::getPaddedDims() {
3394 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3395 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3396 for (const auto &en : enumerate(paddingWidths))
3397 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3398 paddedDims.set(en.index());
3399 };
3400 extractPaddedDims(getMixedLowPad());
3401 extractPaddedDims(getMixedHighPad());
3402 return paddedDims;
3403}
3404
3405namespace {
3406// Folds tensor.pad when padding is static zeros and the attribute
3407// doesn't request otherwise.
3408struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3409 using OpRewritePattern<PadOp>::OpRewritePattern;
3410
3411 LogicalResult matchAndRewrite(PadOp padTensorOp,
3412 PatternRewriter &rewriter) const override {
3413 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3414 return failure();
3415 if (padTensorOp.getNofold())
3416 return failure();
3417 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3418 padTensorOp, padTensorOp.getResult().getType(),
3419 padTensorOp.getSource());
3420 return success();
3421 }
3422};
3423
3424// Fold CastOp into PadOp when adding static information.
3425struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3426 using OpRewritePattern<PadOp>::OpRewritePattern;
3427
3428 LogicalResult matchAndRewrite(PadOp padTensorOp,
3429 PatternRewriter &rewriter) const override {
3430 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3431 if (!tensor::canFoldIntoConsumerOp(castOp))
3432 return failure();
3433
3434 auto newResultType = PadOp::inferResultType(
3435 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3436 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3437 padTensorOp.getResultType().getShape());
3438
3439 if (newResultType == padTensorOp.getResultType()) {
3440 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3441 padTensorOp.getSourceMutable().assign(castOp.getSource());
3442 });
3443 } else {
3444 auto newOp = PadOp::create(
3445 rewriter, padTensorOp->getLoc(), newResultType,
3446 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3447 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3448 padTensorOp.getHigh(), padTensorOp.getNofold(),
3449 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3450 IRMapping mapper;
3451 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3452
3453 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3454 padTensorOp, padTensorOp.getResultType(), newOp);
3455 }
3456 return success();
3457 }
3458};
3459
// Fold CastOp using the result of PadOp back into the latter if it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Only fire when the cast is the sole user, so the pad's result type
    // can be rewritten without affecting other consumers.
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    // The cast must not discard static shape information.
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    // Recreate the pad with the cast's destination type, moving the
    // pad-value region over wholesale.
    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    // Both the pad and the trailing cast now produce the replacement value.
    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
3490
/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
/// different dimensions. The pattern applies if the following preconditions
/// hold:
/// 1) the tensor::ExtractSliceOps are not rank-reducing,
/// 2) the tensor::ExtractSliceOps have only unit-strides,
/// 3) the tensor::PadOps perform only high-padding,
/// 4) the tensor::PadOps have the same constant padding value,
/// 5) the tensor::PadOps do not have common padding dimensions,
/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
///    zero-offset for every dimension.
/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
///    the padded source dimensions.
///
/// Example:
///
/// ```mlir
///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
///       : tensor<64x64xf32> to tensor<?x64xf32>
///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
///     } : tensor<?x64xf32> to tensor<8x64xf32>
///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
///        : tensor<8x64xf32> to tensor<8x?xf32>
///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
///     } : tensor<8x?xf32> to tensor<8x4xf32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
///        : tensor<64x64xf32> to tensor<?x?xf32>
///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
///     } : tensor<?x?xf32> to tensor<8x4xf32>
/// ```
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // Match the chain pad(extract_slice(pad(extract_slice(x)))).
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // for every dimension, and use the offset the other pair. Fail if no
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // exists.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      // Inner pair is a no-op in this dim: take the outer pair's offset.
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      // Outer pair is a no-op in this dim: take the inner pair's offset.
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
    // of the outer tensor::ExtractSliceOp for the dimensions padded by the
    // outer tensor::PadOp and fail if the size of the inner
    // tensor::ExtractSliceOp does not match the size of the padded dimension.
    // Otherwise, take the size of the inner tensor::ExtractSliceOp.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps. The paddings are
    // disjoint (checked in 5), so at most one of the two assignments fires
    // per dimension.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // the two paddings in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
3648
/// Folds constant padding operands of a tensor.pad into its static low/high
/// attributes, refines the result type accordingly, and casts back to the
/// original (less static) result type for existing users.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        // Not a constant: keep it as a dynamic operand of the new op.
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    // Merge the per-operand constants into the static arrays; each kDynamic
    // entry consumes the next dynamic operand in order.
    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        // A dim becomes static only if the input dim and both paddings are.
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    // Bail out if nothing was learned: the shape is unchanged or is still
    // fully dynamic.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);

    return success();
  }
};
3748
/// Folds a chain of `tensor.pad` ops with the same constant padding value.
///
/// Example:
///
/// ```mlir
///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
///       tensor.yield %val
///     } : tensor<1x2xf32> to tensor<2x5xf32>
///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
///       tensor.yield %val
///     } : tensor<1x5xf32> to tensor<5x7xf32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
///       tensor.yield %val
///     } : tensor<1x2xf32> to tensor<5x7xf32>
/// ```
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // Neither pad may be marked nofold.
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps by summing
    // them element-wise with a folded affine apply (d0 + d1).
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    // The merged pad reads directly from the producer's source and keeps
    // the consumer's result type and pad-value region.
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
3825
3826} // namespace
3827
3828LogicalResult
3829PadOp::reifyResultShapes(OpBuilder &b,
3830 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3831 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3832 SmallVector<OpFoldResult> lp = getMixedLowPad();
3833 SmallVector<OpFoldResult> hp = getMixedHighPad();
3834 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3835 if (!getType().isDynamicDim(i)) {
3836 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3837 continue;
3838 }
3839 Location loc = getLoc();
3840 Value dim = b.createOrFold<tensor::DimOp>(
3841 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3842
3843 AffineExpr d0, d1, d2;
3844 bindDims(b.getContext(), d0, d1, d2);
3845 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3846 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3847 }
3848 return success();
3849}
3850
// Registers all tensor.pad canonicalization patterns defined in this file.
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
3857
3858/// Return the padding value of the PadOp if it constant. In this context,
3859/// "constant" means an actual constant or "defined outside of the block".
3860///
3861/// Values are considered constant in three cases:
3862/// - A ConstantLike value.
3863/// - A basic block argument from a different block.
3864/// - A value defined outside of the block.
3865///
3866/// If the padding value is not constant, an empty Value is returned.
3867Value PadOp::getConstantPaddingValue() {
3868 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3869 if (!yieldOp)
3870 return {};
3871 Value padValue = yieldOp.getValue();
3872 // Check if yield value is a constant.
3873 if (matchPattern(padValue, m_Constant()))
3874 return padValue;
3875 // Check if yield value is defined inside the PadOp block.
3876 if (padValue.getParentBlock() == &getRegion().front())
3877 return {};
3878 // Else: Yield value defined outside of the PadOp block.
3879 return padValue;
3880}
3881
3882OpFoldResult PadOp::fold(FoldAdaptor) {
3883 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3884 !getNofold())
3885 return getSource();
3886 return {};
3887}
3888
3889//===----------------------------------------------------------------------===//
3890// ParallelInsertSliceOp
3891//===----------------------------------------------------------------------===//
3892
3893OpResult ParallelInsertSliceOp::getTiedOpResult() {
3894 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3895 for (const auto &it :
3896 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3897 Operation &nextOp = it.value();
3898 if (&nextOp == getOperation())
3899 return parallelCombiningParent.getParentResult(it.index());
3900 }
3901 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3902}
3903
3904// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3905void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3906 Value source, Value dest,
3907 ArrayRef<OpFoldResult> offsets,
3908 ArrayRef<OpFoldResult> sizes,
3909 ArrayRef<OpFoldResult> strides,
3910 ArrayRef<NamedAttribute> attrs) {
3911 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3912 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3913 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3914 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3915 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3916 result.addAttributes(attrs);
3917 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3918 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3919 b.getDenseI64ArrayAttr(staticSizes),
3920 b.getDenseI64ArrayAttr(staticStrides));
3921}
3922
3923/// Build an ParallelInsertSliceOp with mixed static and dynamic entries
3924/// packed into a Range vector.
3925void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3926 Value source, Value dest,
3927 ArrayRef<Range> ranges,
3928 ArrayRef<NamedAttribute> attrs) {
3929 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3930 build(b, result, source, dest, offsets, sizes, strides, attrs);
3931}
3932
3933// Build a ParallelInsertSliceOp with dynamic entries.
3934void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3935 Value source, Value dest, ValueRange offsets,
3936 ValueRange sizes, ValueRange strides,
3937 ArrayRef<NamedAttribute> attrs) {
3938 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3939 offsets, [](Value v) -> OpFoldResult { return v; });
3940 SmallVector<OpFoldResult> sizeValues =
3941 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3942 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3943 strides, [](Value v) -> OpFoldResult { return v; });
3944 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3945}
3946
3947// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3948// to 0, strides set to 1 and inferred result type.
3949void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3950 Value dest, ArrayRef<OpFoldResult> sizes,
3951 ArrayRef<NamedAttribute> attrs) {
3952 Attribute zeroIdxAttr = b.getIndexAttr(0);
3953 Attribute oneIdxAttr = b.getIndexAttr(1);
3954 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3955 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3956 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3957}
3958
3959LogicalResult ParallelInsertSliceOp::verify() {
3960 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3961 return this->emitError("expected InParallelOpInterface parent, got:")
3962 << *(getOperation()->getParentOp());
3963
3964 // Verify result type against inferred type.
3965 RankedTensorType expectedType;
3967 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3968 getStaticSizes(), getStaticStrides(), &expectedType);
3970 return produceSliceErrorMsg(result, *this, expectedType);
3971
3972 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3973 // to the destination tensor.
3974 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3975 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3976 getStaticStrides(), /*generateErrorMessage=*/true);
3977 if (!boundsResult.isValid)
3978 return getOperation()->emitError(boundsResult.errorMessage);
3979
3980 return success();
3981}
3982
// Registers the shared insert_slice canonicalizations (constant-argument
// folding, cast folding, source-cast insertion) for the parallel variant.
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
3989
// Returns the set of source dimensions this op drops, as computed by the
// file-local ::getDroppedDims helper from the source shape and mixed sizes.
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
3993
// ParallelCombiningOpInterface implementation.
// The destination operand is the only operand this op updates.
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}
3998
3999Operation *ParallelInsertSliceOp::getIteratingParent() {
4000 // Return the parent InParallelOpInterface's parent.
4001 if (auto combiningOp =
4002 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
4003 return combiningOp->getParentOp();
4004 return nullptr;
4005}
4006
4007//===----------------------------------------------------------------------===//
4008// ScatterOp
4009//===----------------------------------------------------------------------===//
4010
// Suggests "scatter" as the printed SSA name for the result.
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
4015
LogicalResult ScatterOp::verify() {
  // Validate the scatter dims against the indices shape and dest rank via
  // the shared gather/scatter helper (it emits its own diagnostics).
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");
  // TODO: we could also check statically that there are fewer leading index
  // tensor dims than the dest dims. If this is not the case, the unique
  // attribute cannot be true.

  // Use the GatherOp::inferResultType on the `dest` type and verify the
  // expected type matches the source type. Both the full and the
  // rank-reduced variants are accepted.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type "
                       "mismatch: "
                       "expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
4048
4049//===----------------------------------------------------------------------===//
4050// SplatOp
4051//===----------------------------------------------------------------------===//
4052
// Builds a splat with an explicit aggregate (result) type; this overload
// only reorders the arguments for the generated builder.
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}
4057
4058void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4059 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4060 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4061 build(builder, result, aggregateType, element, dynamicSizes);
4062}
4063
4064void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4065 ArrayRef<OpFoldResult> sizes) {
4066 SmallVector<int64_t> staticShape;
4067 SmallVector<Value> dynamicSizes;
4068 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4069 build(builder, result, element, staticShape, dynamicSizes);
4070}
4071
// Suggests "splat" as the printed SSA name for the result.
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
4076
// Delegates verification of the dynamic size operands against the result
// type to the shared helper.
LogicalResult SplatOp::verify() {
  return verifyDynamicDimensionCount(getOperation(), getType(),
                                     getDynamicSizes());
}
4081
4082LogicalResult
4083SplatOp::reifyResultShapes(OpBuilder &builder,
4084 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4085 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4086 unsigned ctr = 0;
4087 for (int64_t i = 0; i < getType().getRank(); ++i) {
4088 if (getType().isDynamicDim(i)) {
4089 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4090 } else {
4091 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4092 }
4093 }
4094 return success();
4095}
4096
4097OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4098 auto constOperand = adaptor.getInput();
4099 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4100 return {};
4101
4102 // Do not fold if the splat is not statically shaped
4103 if (!getType().hasStaticShape())
4104 return {};
4105
4106 // SplatElementsAttr::get treats single value for second arg as being a
4107 // splat.
4108 return SplatElementsAttr::get(getType(), {constOperand});
4109}
4110
4111//===----------------------------------------------------------------------===//
4112// Common Canonicalizers and Folders.
4113//===----------------------------------------------------------------------===//
4114static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4115 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4116 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4117 // might need special handling of attached regions.
4118 if (isa<InsertSliceOp>(op.getOperation()) ||
4119 isa<LoopLikeOpInterface>(op.getOperation()))
4120 return false;
4121
4123}
4124
4125/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4126/// the `tensor.cast` has source that is more static than the consuming op.
4127///
4128/// Example:
4129/// ```mlir
4130/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4131/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4132/// ```
4133///
4134/// folds into:
4135///
4136/// ```mlir
4137/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4138/// ```
4139/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4140/// can add the pattern to their canonicalizers.
4142 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4144 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4145
4146 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4147 PatternRewriter &rewriter) const override {
4148
4149 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4150 // for that instead.
4151 if (!foldTensorCastPrecondition(op) ||
4152 isa<linalg::RelayoutOpInterface>(*op))
4153 return failure();
4154
4155 SmallVector<Type> newResultTypes(op->getResultTypes());
4156 SmallVector<Value> newOperands =
4157 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4158
4159 // Clone op
4160 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4161
4162 SmallVector<Value, 4> replacements;
4163 replacements.reserve(newOp->getNumResults());
4164 for (auto [oldResult, newResult] :
4165 llvm::zip(op->getResults(), newOp->getResults())) {
4166 if (newResult.getType() != oldResult.getType()) {
4167 replacements.push_back(tensor::CastOp::create(
4168 rewriter, op->getLoc(), oldResult.getType(), newResult));
4169 } else {
4170 replacements.push_back(newResult);
4171 }
4172 }
4173 rewriter.replaceOp(op, replacements);
4174
4175 return success();
4176 }
4177};
4178
4179//===----------------------------------------------------------------------===//
4180// TensorDialect
4181//===----------------------------------------------------------------------===//
4182
// Dialect-level canonicalization: fold producing tensor.cast ops into
// destination-style consumers.
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}
4187
4188//===----------------------------------------------------------------------===//
4189// TableGen'd op method definitions
4190//===----------------------------------------------------------------------===//
4191
4192#define GET_OP_CLASSES
4193#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:382
MLIRContext * getContext() const
Definition Builders.h:56
auto value_begin() const
Get an iterator of the given type to the start of the held element values.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:78
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:93
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.