TensorOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
18#include "mlir/IR/Builders.h"
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/Matchers.h"
32#include "mlir/Support/LLVM.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/SmallBitVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/MathExtras.h"
39#include <optional>
40
41using namespace mlir;
42using namespace mlir::tensor;
43
44/// Materialize a single constant operation from a given attribute value with
45/// the desired resultant type.
46Operation *TensorDialect::materializeConstant(OpBuilder &builder,
47 Attribute value, Type type,
48 Location loc) {
49 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
50 return op;
51 if (complex::ConstantOp::isBuildableWith(value, type))
52 return complex::ConstantOp::create(builder, loc, type,
53 llvm::cast<ArrayAttr>(value));
54 return nullptr;
55}
56
57OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
58 int64_t dim) {
59 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
60 if (tensorType.isDynamicDim(dim))
61 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
62
63 return builder.getIndexAttr(tensorType.getDimSize(dim));
64}
65
66SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
67 Location loc, Value value) {
68 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
69 SmallVector<OpFoldResult> result;
70 for (int64_t i = 0; i < tensorType.getRank(); ++i)
71 result.push_back(getMixedSize(builder, loc, value, i));
72 return result;
73}
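As a usage sketch (not part of the upstream file): a transform that needs a fresh tensor shaped like an existing SSA value can combine getMixedSizes with tensor.empty, much as getOrCreateDestination does below. The helper name makeEmptyLike is hypothetical.

```cpp
// Hypothetical helper, assuming `source` has a RankedTensorType.
static Value makeEmptyLike(OpBuilder &builder, Location loc, Value source) {
  auto type = llvm::cast<RankedTensorType>(source.getType());
  // Static dims become index attributes, dynamic dims become tensor.dim values.
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(builder, loc, source);
  return tensor::EmptyOp::create(builder, loc, sizes, type.getElementType());
}
```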
74
75FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
76 OpResult opResult) {
77 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
78 assert(tensorType && "expected tensor type");
79
80 // If the op has a destination, it implements DestinationStyleOpInterface and
81 // we can query the destination operand from that interface.
82 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
83 if (destOp)
84 return destOp.getTiedOpOperand(opResult)->get();
85
86 // Otherwise, create a new destination tensor with the same shape.
87 OpBuilder::InsertionGuard g(b);
88 b.setInsertionPoint(opResult.getDefiningOp());
89
90 // Compute sizes.
91 SmallVector<OpFoldResult> mixedSizes;
92 if (!tensorType.hasStaticShape()) {
93 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
94 ReifiedRankedShapedTypeDims reifiedShapes;
95 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
96 return failure();
97 mixedSizes = reifiedShapes[opResult.getResultNumber()];
98 } else {
99 // Static shape: Take static sizes directly.
100 for (int64_t sz : tensorType.getShape())
101 mixedSizes.push_back(b.getIndexAttr(sz));
102 }
103
104 // Create empty tensor.
105 Value emptyTensor =
106 tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
107 return emptyTensor;
108}
109
110LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
111 Operation *op,
112 SmallVector<Value> &result) {
113 for (OpResult opResult : op->getResults()) {
114 if (llvm::isa<TensorType>(opResult.getType())) {
115 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
116 if (failed(destination))
117 return failure();
118 result.push_back(*destination);
119 }
120 }
121 return success();
122}
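A minimal sketch (hypothetical wrapper, not upstream API) of how a rewrite might gather destination values for every tensor result before building a destination-style replacement:

```cpp
// Returns one destination value per tensor result of `op`, or failure if the
// result shapes could not be reified.
static FailureOr<SmallVector<Value>> collectDestinations(OpBuilder &b,
                                                         Operation *op) {
  SmallVector<Value> dests;
  if (failed(tensor::getOrCreateDestinations(b, op->getLoc(), op, dests)))
    return failure();
  return dests;
}
```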
123
124bool mlir::tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
125 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
126 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
127 return rtp1.getShape() == rtp2.getShape() &&
128 rtp1.getElementType() == rtp2.getElementType();
129 return false;
130 }
131 return tp1 == tp2; // default implementation
132}
133
134/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135/// rank-extending tensor.insert_slice op.
136static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
137 ArrayRef<OpFoldResult> mixedSizes) {
138 llvm::SmallBitVector droppedDims(mixedSizes.size());
139 int64_t shapePos = reducedShape.size() - 1;
140
141 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
142 size_t idx = mixedSizes.size() - size.index() - 1;
143 // Rank-reduced dims must have a static unit dimension.
144 bool isStaticUnitSize =
145 isa<Attribute>(size.value()) &&
146 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
147
148 if (shapePos < 0) {
149 // There are no more dims in the reduced shape. All remaining sizes must
150 // be rank-reduced dims.
151 assert(isStaticUnitSize && "expected unit dim");
152 droppedDims.set(idx);
153 continue;
154 }
155
156 // Dim is preserved if the size is not a static 1.
157 if (!isStaticUnitSize) {
158 --shapePos;
159 continue;
160 }
161
162 // Dim is preserved if the reduced shape dim is also 1.
163 if (reducedShape[shapePos] == 1) {
164 --shapePos;
165 continue;
166 }
167
168 // Otherwise: Dim is dropped.
169 droppedDims.set(idx);
170 }
171
172 assert(shapePos < 0 && "dimension mismatch");
173 return droppedDims;
174}
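A worked example of the bookkeeping above, written as a file-local sketch (getDroppedDims is static to this file; `dynamicSize` is an assumed non-constant index value):

```cpp
// For an extract_slice with mixed sizes [1, %d, 1, 6] rank-reduced to
// tensor<?x6xf32>, the static unit sizes at positions 0 and 2 are dropped.
static void exampleDroppedDims(Builder &b, Value dynamicSize) {
  SmallVector<OpFoldResult> mixedSizes = {b.getIndexAttr(1), dynamicSize,
                                          b.getIndexAttr(1), b.getIndexAttr(6)};
  SmallVector<int64_t> reducedShape = {ShapedType::kDynamic, 6};
  llvm::SmallBitVector dropped = getDroppedDims(reducedShape, mixedSizes);
  assert(dropped.test(0) && dropped.test(2) && "unit dims 0 and 2 are dropped");
  (void)dropped;
}
```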
175
176/// Given a ranked tensor type and a range of values that defines its dynamic
177/// dimension sizes, turn all dynamic sizes that have a constant value into
178/// static dimension sizes.
179static RankedTensorType
180foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
181 SmallVector<Value> &foldedDynamicSizes) {
182 SmallVector<int64_t> staticShape(type.getShape());
183 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
184 "incorrect number of dynamic sizes");
185
186 // Compute new static and dynamic sizes.
187 unsigned ctr = 0;
188 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189 if (type.isDynamicDim(i)) {
190 Value dynamicSize = dynamicSizes[ctr++];
191 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192 if (cst.has_value()) {
193 // Dynamic size must be non-negative.
194 if (cst.value() < 0) {
195 foldedDynamicSizes.push_back(dynamicSize);
196 continue;
197 }
198 staticShape[i] = *cst;
199 } else {
200 foldedDynamicSizes.push_back(dynamicSize);
201 }
202 }
203 }
204
205 return RankedTensorType::get(staticShape, type.getElementType(),
206 type.getEncoding());
207}
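Illustrative sketch of the folding above (file-local; `dynSize` is an assumed non-constant index value):

```cpp
// A tensor<?x?xf32> whose second dynamic size is the constant 5 folds to
// tensor<?x5xf32>; only the still-dynamic first size is kept.
static void exampleFoldDynSizes(OpBuilder &b, Location loc, Value dynSize) {
  auto type = RankedTensorType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, b.getF32Type());
  Value five = arith::ConstantIndexOp::create(b, loc, 5);
  SmallVector<Value> dynamicSizes = {dynSize, five};
  SmallVector<Value> folded;
  RankedTensorType newType =
      foldDynamicToStaticDimSizes(type, dynamicSizes, folded);
  assert(newType.getDimSize(1) == 5 && folded.size() == 1);
  (void)newType;
}
```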
208
209//===----------------------------------------------------------------------===//
210// BitcastOp
211//===----------------------------------------------------------------------===//
212
213bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214 if (inputs.size() != 1 || outputs.size() != 1)
215 return false;
216 Type a = inputs.front(), b = outputs.front();
217 auto aT = dyn_cast<TensorType>(a);
218 auto bT = dyn_cast<TensorType>(b);
219 if (!aT || !bT)
220 return false;
221
222 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223 return false;
224
225 return succeeded(verifyCompatibleShape(aT, bT));
226}
227
228namespace {
229
230/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231/// operation.
232struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233 using OpRewritePattern<BitcastOp>::OpRewritePattern;
234
235 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236 PatternRewriter &rewriter) const final {
237 auto tensorBitcastOperand =
238 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239 if (!tensorBitcastOperand)
240 return failure();
241
242 auto resultType = cast<TensorType>(tensorBitcast.getType());
243 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244 tensorBitcastOperand.getOperand());
245 return success();
246 }
247};
248
249} // namespace
250
251void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252 MLIRContext *context) {
253 results.add<ChainedTensorBitcast>(context);
254}
255
256//===----------------------------------------------------------------------===//
257// CastOp
258//===----------------------------------------------------------------------===//
259
260void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261 setNameFn(getResult(), "cast");
262}
263
264/// Returns true if `target` is a ranked tensor type that preserves static
265/// information available in the `source` ranked tensor type.
266bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
267 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
269
270 // Requires RankedTensorType.
271 if (!sourceType || !targetType)
272 return false;
273
274 // Requires same elemental type.
275 if (sourceType.getElementType() != targetType.getElementType())
276 return false;
277
278 // Requires same rank.
279 if (sourceType.getRank() != targetType.getRank())
280 return false;
281
282 // Requires same encoding.
283 if (sourceType.getEncoding() != targetType.getEncoding())
284 return false;
285
286 // If cast is towards more static sizes along any dimension, don't fold.
287 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288 if (ShapedType::isStatic(std::get<0>(t)) &&
289 ShapedType::isDynamic(std::get<1>(t)))
290 return false;
291 }
292
293 return true;
294}
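Illustrative check of the contract (unqualified calls work here because of the `using namespace mlir::tensor` above):

```cpp
// `target` must retain every statically known dimension of `source`.
static void exampleStaticInfo(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  auto static8x16 = RankedTensorType::get({8, 16}, f32);
  auto dynX16 = RankedTensorType::get({ShapedType::kDynamic, 16}, f32);
  // target 8x16 keeps the static 16 that source ?x16 already had: preserved.
  assert(preservesStaticInformation(/*source=*/dynX16, /*target=*/static8x16));
  // target ?x16 forgets the static 8 of source 8x16: not preserved.
  assert(!preservesStaticInformation(/*source=*/static8x16, /*target=*/dynX16));
}
```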
295
296/// Determines whether tensor::CastOp casts to a more dynamic version of the
297/// source tensor. This is useful to fold a tensor.cast into a consuming op and
298/// implement canonicalization patterns for ops in different dialects that may
299/// consume the results of tensor.cast operations. Such foldable tensor.cast
300/// operations are typically inserted as `slice` ops and are canonicalized
301/// to preserve the type compatibility of their uses.
302///
303/// Returns true when all conditions are met:
304/// 1. source and result are ranked tensors with same element type and rank.
305/// 2. the source type has more static information than the result type.
306///
307/// Example:
308/// ```mlir
309/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310/// %2 = consumer %1 ... : tensor<?x?xf32> ...
311/// ```
312///
313/// folds into:
314///
315/// ```mlir
316/// %2 = consumer %0 ... : tensor<8x16xf32> ...
317/// ```
318bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
319 if (!castOp)
320 return false;
321
322 // Can fold if the source of cast has at least as much static information as
323 // its results.
324 return preservesStaticInformation(castOp.getType(),
325 castOp.getSource().getType());
326}
327
328/// Determines whether the tensor::CastOp casts to a more static version of the
329/// source tensor. This is useful to fold into a producing op and implement
330/// canonicalization patterns with the `tensor.cast` op as the root, but the
331/// producer possibly being from a different dialect. Returns true when all
332/// conditions are met:
333/// 1. source and result are ranked tensors with same element type and rank.
334/// 2. the result type has more static information than the source.
335///
336/// Example:
337/// ```mlir
338/// %1 = producer ... : tensor<?x?xf32>
339/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
340/// ```
341///
342/// can be canonicalized to:
343///
344/// ```mlir
345/// %2 = producer ... : tensor<8x16xf32>
346/// ```
347/// Not all ops might be canonicalizable this way, but for those that can be,
348/// this method provides a check that it is worth doing the canonicalization.
349bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
350 if (!castOp)
351 return false;
352 return preservesStaticInformation(castOp.getSource().getType(),
353 castOp.getType());
354}
355
356bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
357 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
358 if (llvm::isa<BlockArgument>(opOperand.get()))
359 return false;
360 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
361 return castOp && canFoldIntoConsumerOp(castOp);
362 });
363}
364
365SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
366 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
367 SmallVector<Value> newOperands;
368 newOperands.reserve(op->getNumOperands());
369
370 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
371
372 // Assumes that the result has dpsInits followed by nonDpsInits.
373 int64_t dpsInitIdx = 0;
374 for (OpOperand &opOperand : op->getOpOperands()) {
375 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
376 bool fold = canFoldIntoConsumerOp(tensorCastOp);
377 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
378 if (op.isDpsInit(&opOperand) &&
379 !llvm::isa<MemRefType>(newOperands.back().getType()))
380 newResTy[dpsInitIdx++] = newOperands.back().getType();
381 }
382 return newOperands;
383}
384
385/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
386/// that can be folded.
387LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
388 bool folded = false;
389 for (OpOperand &operand : op->getOpOperands()) {
390 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
391 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
392 operand.set(castOp.getOperand());
393 folded = true;
394 }
395 }
396 return success(folded);
397}
398
399bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
400 if (inputs.size() != 1 || outputs.size() != 1)
401 return false;
402 Type a = inputs.front(), b = outputs.front();
403 auto aT = llvm::dyn_cast<TensorType>(a);
404 auto bT = llvm::dyn_cast<TensorType>(b);
405 if (!aT || !bT)
406 return false;
407
408 if (aT.getElementType() != bT.getElementType())
409 return false;
410
411 return succeeded(verifyCompatibleShape(aT, bT));
412}
413
414/// Compute a TensorType that has the joined shape knowledge of the two
415/// given TensorTypes. The element types need to match.
416static TensorType joinShapes(TensorType one, TensorType two) {
417 assert(one.getElementType() == two.getElementType());
418
419 if (!one.hasRank())
420 return two;
421 if (!two.hasRank())
422 return one;
423
424 int64_t rank = one.getRank();
425 if (rank != two.getRank())
426 return {};
427
428 SmallVector<int64_t, 4> join;
429 join.reserve(rank);
430 for (int64_t i = 0; i < rank; ++i) {
431 if (one.isDynamicDim(i)) {
432 join.push_back(two.getDimSize(i));
433 continue;
434 }
435 if (two.isDynamicDim(i)) {
436 join.push_back(one.getDimSize(i));
437 continue;
438 }
439 if (one.getDimSize(i) != two.getDimSize(i))
440 return {};
441 join.push_back(one.getDimSize(i));
442 }
443 return RankedTensorType::get(join, one.getElementType());
444}
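A file-local sketch of how joinShapes combines shape knowledge (joinShapes is static to this file):

```cpp
// Complementary static sizes merge; conflicting static sizes yield a null type.
static void exampleJoinShapes(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  int64_t dyn = ShapedType::kDynamic;
  TensorType a = RankedTensorType::get({dyn, 8}, f32);
  TensorType b = RankedTensorType::get({4, dyn}, f32);
  assert(joinShapes(a, b) == RankedTensorType::get({4, 8}, f32));
  assert(!joinShapes(RankedTensorType::get({4, 8}, f32),
                     RankedTensorType::get({5, 8}, f32)));
}
```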
445
446namespace {
447
448/// Replaces chains of two tensor.cast operations by a single tensor.cast
449/// operation if doing so does not remove runtime constraints.
450struct ChainedTensorCast : public OpRewritePattern<CastOp> {
451 using OpRewritePattern<CastOp>::OpRewritePattern;
452
453 LogicalResult matchAndRewrite(CastOp tensorCast,
454 PatternRewriter &rewriter) const final {
455 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
456
457 if (!tensorCastOperand)
458 return failure();
459
460 auto sourceType =
461 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
462 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
463 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
464
465 // We can remove the intermediate cast if joining all three produces the
466 // same result as just joining the source and result shapes.
467 auto firstJoin =
468 joinShapes(joinShapes(sourceType, intermediateType), resultType);
469
470 // The join might not exist if the cast sequence would fail at runtime.
471 if (!firstJoin)
472 return failure();
473
474 // The newJoin always exists if the above join exists; it might just contain
475 // less information. If so, we cannot drop the intermediate cast, as doing
476 // so would remove runtime checks.
477 auto newJoin = joinShapes(sourceType, resultType);
478 if (firstJoin != newJoin)
479 return failure();
480
481 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
482 tensorCastOperand.getOperand());
483 return success();
484 }
485};
486
487/// Fold tensor.cast into tensor.extract_slice producer.
488/// Example:
489/// ```
490/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
491/// tensor<128x512xf32> to tensor<?x512xf32>
492/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
493/// ```
494/// ->
495/// ```
496/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
497/// tensor<128x512xf32> to tensor<16x512xf32>
498/// ```
499struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
500 using OpRewritePattern<CastOp>::OpRewritePattern;
501
502 LogicalResult matchAndRewrite(CastOp tensorCast,
503 PatternRewriter &rewriter) const final {
504 auto extractOperand =
505 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
506
507 // Cannot fold cast to unranked tensor.
508 auto rankedResultType =
509 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
510 if (!rankedResultType)
511 return failure();
512
513 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
514 rankedResultType.getShape() ==
515 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
516 .getShape())
517 return failure();
518
519 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
520 auto dimMask = computeRankReductionMask(
521 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
522 size_t dimIndex = 0;
523 for (size_t i = 0, e = sizes.size(); i < e; i++) {
524 if (dimMask && dimMask->count(i))
525 continue;
526 int64_t dim = rankedResultType.getShape()[dimIndex++];
527 if (ShapedType::isDynamic(dim))
528 continue;
529 sizes[i] = rewriter.getIndexAttr(dim);
530 }
531
532 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
533 tensorCast, rankedResultType, extractOperand.getSource(),
534 extractOperand.getMixedOffsets(), sizes,
535 extractOperand.getMixedStrides());
536 return success();
537 }
538};
539
540} // namespace
541
542void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
543 MLIRContext *context) {
544 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
545}
546
547//===----------------------------------------------------------------------===//
548// ConcatOp
549//===----------------------------------------------------------------------===//
550
551RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
552 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
553 auto tensorTypes =
554 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
555 int64_t concatRank = tensorTypes[0].getRank();
556
557 // The concatenation dim must be in the range [0, rank).
558 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
559
560 SmallVector<int64_t> sizes(concatRank);
561 for (int64_t i = 0, e = concatRank; i < e; ++i) {
562 if (i == dim)
563 continue;
564 SaturatedInteger size;
565 for (auto tensorType : tensorTypes)
566 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
567 sizes[i] = size.asInteger();
568 }
569 auto concatSize = SaturatedInteger::wrap(0);
570 for (auto tensorType : tensorTypes)
571 concatSize =
572 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
573 sizes[dim] = concatSize.asInteger();
574 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
575}
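Illustrative inference example: the non-concatenated dimension keeps the most static size seen among the inputs, while the concatenated dimension is a saturating sum.

```cpp
// Concatenating tensor<4x?xf32> and tensor<?x8xf32> along dim 0 infers
// tensor<?x8xf32>: dim 1 refines to 8, dim 0 stays dynamic (4 + ? = ?).
static void exampleConcatInference(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  int64_t dyn = ShapedType::kDynamic;
  SmallVector<Type> inputs = {RankedTensorType::get({4, dyn}, f32),
                              RankedTensorType::get({dyn, 8}, f32)};
  RankedTensorType inferred = ConcatOp::inferResultType(/*dim=*/0, inputs);
  assert(inferred == RankedTensorType::get({dyn, 8}, f32));
  (void)inferred;
}
```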
576
577void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
578 ValueRange inputs) {
579 FailureOr<RankedTensorType> resultType =
580 inferResultType(dim, inputs.getTypes());
581 assert(succeeded(resultType) && "failed to infer concatenation result type");
582 build(builder, result, *resultType, dim, inputs);
583}
584
585LogicalResult ConcatOp::verify() {
586 if (getInputs().size() < 1)
587 return emitOpError("requires at least one input");
588
589 SmallVector<RankedTensorType> inputTypes;
590 for (auto input : getInputs())
591 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
592
593 RankedTensorType resultType = getResultType();
594 int64_t resultRank = getRank();
595 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
596 return type.getRank() != resultRank;
597 }))
598 return emitOpError("rank of concatenated inputs must match result rank");
599
600 Type resultElementType = resultType.getElementType();
601 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
602 return type.getElementType() != resultElementType;
603 }))
604 return emitOpError("inputs and result element type must match");
605
606 int64_t dim = getDim();
607 if (dim >= resultRank)
608 return emitOpError("concatenation dim must be less than the tensor rank");
609
610 SmallVector<int64_t> sizes(resultRank);
611 for (int64_t i = 0, e = resultRank; i < e; ++i) {
612 if (i == dim)
613 continue;
614 SaturatedInteger size;
615 for (auto tensorType : inputTypes) {
616 FailureOr<SaturatedInteger> maybeSize =
617 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
618 if (failed(maybeSize))
619 return emitOpError("static concatenation size mismatch along ")
620 << "non-concatenated dimension " << i;
621 size = *maybeSize;
622 }
623 sizes[i] = size.asInteger();
624 }
625 auto concatSize = SaturatedInteger::wrap(0);
626 for (auto tensorType : inputTypes)
627 concatSize =
628 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
629 sizes[dim] = concatSize.asInteger();
630 auto inferredResultType =
631 RankedTensorType::get(sizes, inputTypes[0].getElementType());
632
633 for (auto [inferredSize, actualSize] :
634 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
635 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
636 ShapedType::isDynamic(actualSize);
637 if (!hasDynamic && inferredSize != actualSize)
638 return emitOpError("result type ")
639 << resultType << " does not match inferred shape "
640 << inferredResultType << " static sizes";
641 }
642
643 return success();
644}
645
646FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
647 size_t numInputs = getInputs().size();
648 uint64_t concatDim = getDim();
649
650 SmallVector<SmallVector<OpFoldResult>> inputShapes;
651 inputShapes.reserve(numInputs);
652 SmallVector<OpFoldResult> concatOffsets;
653 concatOffsets.reserve(numInputs);
654 SmallVector<OpFoldResult> outputShape;
655
656 AffineExpr addExpr =
657 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
658 OpFoldResult zero = builder.getIndexAttr(0);
659 Location loc = getLoc();
660 for (auto [index, input] : llvm::enumerate(getInputs())) {
661 SmallVector<OpFoldResult> inputShape =
662 tensor::getMixedSizes(builder, input.getLoc(), input);
663 if (index == 0) {
664 outputShape = inputShape;
665 concatOffsets.push_back(zero);
666 } else {
667 concatOffsets.push_back(outputShape[concatDim]);
668 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
669 builder, loc, addExpr,
670 {outputShape[concatDim], inputShape[concatDim]});
671 }
672 inputShapes.emplace_back(std::move(inputShape));
673 }
674
675 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
676 getType().getElementType());
677
678 int64_t rank = getType().getRank();
679 OpFoldResult one = builder.getIndexAttr(1);
680 SmallVector<OpFoldResult> strides(rank, one);
681 SmallVector<OpFoldResult> offsets(rank, zero);
682 for (auto [index, input] : llvm::enumerate(getInputs())) {
683 offsets[concatDim] = concatOffsets[index];
684 auto insertSlice = tensor::InsertSliceOp::create(
685 builder, loc, input, replacement, offsets, inputShapes[index], strides);
686 replacement = insertSlice.getResult();
687 }
688 if (replacement.getType() != getType()) {
689 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
690 }
691 return SmallVector<Value>{replacement};
692}
693
694LogicalResult
695ConcatOp::reifyResultShapes(OpBuilder &builder,
696 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
697 ValueRange inputs = getInputs();
698 int64_t dim = getDim();
699 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
700
701 Value init = inputs[0];
702 int64_t rank = getType().getRank();
703
704 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
705
706 // Pre-populate the result sizes with as much static information as possible
707 // from the given result type, as well as the inferred result type, otherwise
708 // use the dim sizes from the first input.
709 for (int64_t i = 0; i < rank; ++i) {
710 if (i == dim)
711 continue;
712 if (!getType().isDynamicDim(i)) {
713 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
714 } else if (!inferredResultType.isDynamicDim(i)) {
715 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
716 builder, getLoc(),
717 builder.getIndexAttr(inferredResultType.getDimSize(i)));
718 } else {
719 reifiedReturnShapes[0][i] =
720 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
721 }
722 }
723
724 if (getType().isDynamicDim(dim)) {
725 // Take the sum of the input sizes along the concatenated dim.
726 AffineExpr sum = builder.getAffineDimExpr(0);
727 SmallVector<OpFoldResult> sizes = {
728 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
729 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
730 sum = sum + builder.getAffineDimExpr(idx + 1);
731 sizes.push_back(
732 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
733 }
734 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
735 builder, getLoc(),
736 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
737 } else {
738 // If the result shape is static along the concatenated dim, use the static
739 // shape.
740 reifiedReturnShapes[0][dim] =
741 builder.getIndexAttr(getType().getDimSize(dim));
742 }
743 return success();
744}
745
746void ConcatOp::getAsmResultNames(
747 function_ref<void(Value, StringRef)> setNameFn) {
748 setNameFn(getResult(), "concat");
749}
750
751OpFoldResult ConcatOp::fold(FoldAdaptor) {
752 ValueRange inputs = getInputs();
753 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
754 return inputs[0];
755 return {};
756}
757
758namespace {
759/// Fold a concat op with a single input to a cast.
760struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
761 using OpRewritePattern<ConcatOp>::OpRewritePattern;
762
763 LogicalResult matchAndRewrite(ConcatOp concatOp,
764 PatternRewriter &rewriter) const override {
765 if (concatOp.getInputs().size() != 1)
766 return failure();
767 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
768 concatOp.getInputs()[0]);
769 return success();
770 }
771};
772
773/// Propagate static shapes into the operands of a `tensor.concat`.
774///
775/// `tensor.concat` requires every operand to match on all dimensions except the
776/// concatenation dimension. If one operand is already static in those
777/// dimensions, the other operands may safely be refined to that same static
778/// shape.
779///
780/// Example:
781///
782/// ```mlir
783/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
784/// tensor<?x12xi32>
785/// ```
786/// ->
787/// ```mlir
788/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
789/// %2 = tensor.concat dim(0) %0, %cast :
790/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
791/// ```
792struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
793 using OpRewritePattern<ConcatOp>::OpRewritePattern;
794
795 LogicalResult matchAndRewrite(ConcatOp concatOp,
796 PatternRewriter &rewriter) const override {
797 int64_t dim = concatOp.getDim();
798 RankedTensorType inferredResultType =
799 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
800
801 // Find operands for which a more static shape can be inferred.
802 LogicalResult matched = failure();
803 // Inferred operand shapes are identical in every dimension except the
804 // concatenation dimension.
805 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
806 for (auto [operandIdx, operandType] :
807 llvm::enumerate(concatOp->getOperandTypes())) {
808 // Compute inferred type for operand.
809 inferredOperandShape[dim] =
810 cast<RankedTensorType>(operandType).getDimSize(dim);
811 auto inferredOperandType = RankedTensorType::get(
812 inferredOperandShape, inferredResultType.getElementType());
813
814 // Check if inferred type is more static.
815 if (!preservesStaticInformation(inferredOperandType, operandType)) {
816 matched = success();
817
818 // Use refined operand type and create cast from original operand.
819 auto castOp =
820 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
821 concatOp.getOperand(operandIdx));
822 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
823 concatOp->setOperand(operandIdx, castOp->getResult(0));
824 });
825 }
826 }
827
828 return matched;
829 }
830};
831
832/// Ensure `tensor.concat`'s result type is at least as static as can be
833/// inferred from its operand types.
834///
835/// Example:
836/// ```mlir
837/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
838/// tensor<?x?xi32>
839/// ```
840/// ->
841/// ```mlir
842/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
843/// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
844/// tensor<?x?xi32>
845/// ```
846struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
847 using OpRewritePattern<ConcatOp>::OpRewritePattern;
848
849 LogicalResult matchAndRewrite(ConcatOp concatOp,
850 PatternRewriter &rewriter) const override {
851 int64_t dim = concatOp.getDim();
852 RankedTensorType inferredResultType =
853 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
854
855 // The result type should be at least as static as inferred result type.
856 if (preservesStaticInformation(inferredResultType,
857 concatOp.getResultType())) {
858 return failure();
859 }
860
861 auto newConcatOp =
862 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
863 concatOp->getOperands());
864 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
865 newConcatOp);
866
867 return success();
868 }
869};
870} // namespace
871
872void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
873 MLIRContext *context) {
874 results
875 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
876 context);
877}
878
879//===----------------------------------------------------------------------===//
880// DimOp
881//===----------------------------------------------------------------------===//
882
883void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
884 setNameFn(getResult(), "dim");
885}
886
887void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
888 int64_t index) {
889 auto loc = result.location;
890 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
891 build(builder, result, source, indexValue);
892}
893
894std::optional<int64_t> DimOp::getConstantIndex() {
895 return getConstantIntValue(getIndex());
896}
897
898Speculation::Speculatability DimOp::getSpeculatability() {
899 auto constantIndex = getConstantIndex();
900 if (!constantIndex)
901 return Speculation::NotSpeculatable;
902
903 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
904 if (!rankedSourceType)
905 return Speculation::NotSpeculatable;
906
907 if (rankedSourceType.getRank() <= constantIndex)
908 return Speculation::NotSpeculatable;
909
910 return Speculation::Speculatable;
911}
912
913void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
914 SetIntLatticeFn setResultRange) {
915 setResultRange(getResult(),
916 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
917}
918
919OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
920 // All forms of folding require a known index.
921 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
922 if (!index)
923 return {};
924
925 // Folding for unranked types (UnrankedTensorType) is not supported.
926 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
927 if (!tensorType)
928 return {};
929
930 // Out of bound indices produce undefined behavior but are still valid IR.
931 // Don't choke on them.
932 int64_t indexVal = index.getInt();
933 if (indexVal < 0 || indexVal >= tensorType.getRank())
934 return {};
935
936 // Fold if the shape extent along the given index is known.
937 if (!tensorType.isDynamicDim(index.getInt())) {
938 Builder builder(getContext());
939 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
940 }
941
942 Operation *definingOp = getSource().getDefiningOp();
943
944 // Fold dim to the operand of tensor.generate.
945 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
946 auto resultType =
947 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
948 // The case where the type encodes the size of the dimension is handled
949 // above.
950 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
951
952 // Find the operand of the fromElements that corresponds to this index.
953 auto dynExtents = fromElements.getDynamicExtents().begin();
954 for (auto dim : resultType.getShape().take_front(index.getInt()))
955 if (ShapedType::isDynamic(dim))
956 dynExtents++;
957
958 return Value{*dynExtents};
959 }
960
961 // The size at the given index is now known to be a dynamic size.
962 unsigned unsignedIndex = index.getValue().getZExtValue();
963
964 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
965 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
966 // `resolve-shaped-type-result-dims` pass.
967 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
968 sliceOp.isDynamicSize(unsignedIndex)) {
969 return {sliceOp.getDynamicSize(unsignedIndex)};
970 }
971 }
972
973 // dim(cast) -> dim
974 if (succeeded(foldTensorCast(*this)))
975 return getResult();
976
977 return {};
978}
979
980namespace {
981/// Fold dim of a cast into the dim of the source of the tensor cast.
982struct DimOfCastOp : public OpRewritePattern<DimOp> {
983 using OpRewritePattern<DimOp>::OpRewritePattern;
984
985 LogicalResult matchAndRewrite(DimOp dimOp,
986 PatternRewriter &rewriter) const override {
987 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
988 if (!castOp)
989 return failure();
990 Value newSource = castOp.getOperand();
991 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
992 return success();
993 }
994};
995
996/// Fold dim of a destination passing style op into the dim of the corresponding
997/// init.
998struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
999 using OpRewritePattern<DimOp>::OpRewritePattern;
1000
1001 LogicalResult matchAndRewrite(DimOp dimOp,
1002 PatternRewriter &rewriter) const override {
1003 auto source = dimOp.getSource();
1004 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1005 if (!destOp)
1006 return failure();
1007
1008 auto resultIndex = cast<OpResult>(source).getResultNumber();
1009 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1010
1011 rewriter.modifyOpInPlace(
1012 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1013 return success();
1014 }
1015};
1016
1017/// Fold the dim of a tensor reshape operation into an extract from the
1018/// reshape's shape operand.
1019struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1020 using OpRewritePattern<DimOp>::OpRewritePattern;
1021
1022 LogicalResult matchAndRewrite(DimOp dim,
1023 PatternRewriter &rewriter) const override {
1024 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1025
1026 if (!reshape)
1027 return failure();
1028
1029 // Since tensors are immutable, we don't need to worry about where to place
1030 // the extract call.
1031 rewriter.setInsertionPointAfter(dim);
1032 Location loc = dim.getLoc();
1033 Value extract =
1034 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1035 if (extract.getType() != dim.getType())
1036 extract =
1037 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1038 rewriter.replaceOp(dim, extract);
1039 return success();
1040 }
1041};
1042} // namespace
1043
1044void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1045 MLIRContext *context) {
1046 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1047}
1048
1049//===----------------------------------------------------------------------===//
1050// EmptyOp
1051//===----------------------------------------------------------------------===//
1052
1053void EmptyOp::build(OpBuilder &builder, OperationState &result,
1054 ArrayRef<int64_t> staticShape, Type elementType,
1055 Attribute encoding) {
1056 assert(none_of(staticShape, ShapedType::isDynamic) &&
1057 "expected only static sizes");
1058 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1059}
1060
1061void EmptyOp::build(OpBuilder &builder, OperationState &result,
1062 ArrayRef<int64_t> staticShape, Type elementType,
1063 ValueRange dynamicSizes, Attribute encoding) {
1064 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1065 build(builder, result, tensorType, dynamicSizes);
1066}
1067
1068void EmptyOp::build(OpBuilder &builder, OperationState &result,
1069 ArrayRef<OpFoldResult> sizes, Type elementType,
1070 Attribute encoding) {
1071 SmallVector<int64_t> staticShape;
1072 SmallVector<Value> dynamicSizes;
1073 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1074 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1075}
1076
1077LogicalResult EmptyOp::verify() {
1078 if (getType().getNumDynamicDims() != getDynamicSizes().size())
1079 return emitOpError("incorrect number of dynamic sizes, has ")
1080 << getDynamicSizes().size() << ", expected "
1081 << getType().getNumDynamicDims();
1082 return success();
1083}
1084
1085LogicalResult
1086EmptyOp::reifyResultShapes(OpBuilder &builder,
1087 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1088 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1089 unsigned ctr = 0;
1090 for (int64_t i = 0; i < getType().getRank(); ++i) {
1091 if (getType().isDynamicDim(i)) {
1092 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1093 } else {
1094 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1095 }
1096 }
1097 return success();
1098}
1099
1100Value EmptyOp::getDynamicSize(unsigned idx) {
1101 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1102 unsigned ctr = 0;
1103 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1104 if (getType().isDynamicDim(i))
1105 ++ctr;
1106 return getDynamicSizes()[ctr];
1107}
1108
1109SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1110 SmallVector<OpFoldResult> result;
1111 unsigned ctr = 0;
1112 OpBuilder b(getContext());
1113 for (int64_t i = 0; i < getType().getRank(); ++i) {
1114 if (getType().isDynamicDim(i)) {
1115 result.push_back(getDynamicSizes()[ctr++]);
1116 } else {
1117 result.push_back(b.getIndexAttr(getType().getShape()[i]));
1118 }
1119 }
1120 return result;
1121}
1122
1123namespace {
1124/// Change the type of the result of a `tensor.empty` by making the result
1125/// type statically sized along dimensions that in the original operation were
1126/// defined as dynamic, but the size was defined using a `constant` op. For
1127/// example
1128///
1129/// %c5 = arith.constant 5: index
1130/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1131///
1132/// to
1133///
1134/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1135struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1136 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1137
1138 LogicalResult matchAndRewrite(EmptyOp op,
1139 PatternRewriter &rewriter) const override {
1140 SmallVector<Value> foldedDynamicSizes;
1141 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1142 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1143
1144 // Stop here if no dynamic size was promoted to static.
1145 if (foldedTensorType == op.getType())
1146 return failure();
1147
1148 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1149 foldedDynamicSizes);
1150 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1151 return success();
1152 }
1153};
1154
1155struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1156 using OpRewritePattern<DimOp>::OpRewritePattern;
1157
1158 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1159 PatternRewriter &rewriter) const override {
1160 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1161 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1162 if (!emptyTensorOp || !maybeConstantIndex)
1163 return failure();
1164 auto emptyTensorType = emptyTensorOp.getType();
1165 if (*maybeConstantIndex < 0 ||
1166 *maybeConstantIndex >= emptyTensorType.getRank() ||
1167 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1168 return failure();
1169 rewriter.replaceOp(dimOp,
1170 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1171 return success();
1172 }
1173};
1174
1175/// Canonicalize
1176///
1177/// ```mlir
1178/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1179/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1180/// ```
1181///
1182/// into
1183///
1184/// ```mlir
1185/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1186/// ```
1187///
1188/// This assumes the input program is correct in terms of its shape. So it is
1189/// safe to assume that `%d0` is in fact 4.
1190struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1191 using OpRewritePattern<CastOp>::OpRewritePattern;
1192
1193 LogicalResult matchAndRewrite(CastOp castOp,
1194 PatternRewriter &rewriter) const override {
1195 if (!canFoldIntoProducerOp(castOp))
1196 return failure();
1197 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1198 if (!producer)
1199 return failure();
1200
1201 auto resultType =
1202 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1203 ArrayRef<int64_t> resultShape = resultType.getShape();
1204 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1205 SmallVector<OpFoldResult> newMixedSizes;
1206 newMixedSizes.reserve(currMixedSizes.size());
1207 assert(resultShape.size() == currMixedSizes.size() &&
1208 "mismatch in result shape and sizes of empty op");
1209 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1210 int64_t newDim = std::get<0>(it);
1211 OpFoldResult currDim = std::get<1>(it);
1212 // Case 1: The empty tensor dim is static. Check that the tensor cast
1213 // result dim matches.
1214 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1215 if (ShapedType::isDynamic(newDim) ||
1216 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1217 // Something is off, the cast result shape cannot be more dynamic
1218 // than the empty tensor result shape (enforced by
1219 // `canFoldIntoProducerOp`). Abort for now.
1220 return rewriter.notifyMatchFailure(
1221 producer, "mismatch in static value of shape of empty tensor "
1222 "result and cast result");
1223 }
1224 newMixedSizes.push_back(attr);
1225 continue;
1226 }
1227
1228 // Case 2 : The tensor cast shape is static, but empty tensor result
1229 // shape is dynamic.
1230 if (ShapedType::isStatic(newDim)) {
1231 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1232 continue;
1233 }
1234
1235 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1236 // shape is dynamic. Use the dynamic value from the empty tensor op.
1237 newMixedSizes.push_back(currDim);
1238 }
1239
1240 // TODO: Do not drop tensor encoding.
1241 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1242 resultType.getElementType());
1243 return success();
1244 }
1245};
1246
1247} // namespace
1248
1249void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1250 MLIRContext *context) {
1251 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1252 ReplaceEmptyTensorStaticShapeDims>(context);
1253}
1254
1255//===----------------------------------------------------------------------===//
1256// ExtractOp
1257//===----------------------------------------------------------------------===//
1258
1259namespace {
1260
1261/// Canonicalizes the pattern of the form
1262///
1263/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1264/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1265///
1266/// to
1267///
1268/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1269struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1270 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1271
1272 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1273 PatternRewriter &rewriter) const final {
1274 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1275 if (!tensorCast)
1276 return failure();
1277 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1278 return failure();
1279 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1280 extract, tensorCast.getSource(), extract.getIndices());
1281 return success();
1282 }
1283};
1284
1285/// Canonicalizes the pattern of the form
1286///
1287/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1288/// tensor<12xf64>
1289/// %extracted_element = tensor.extract %val[%c10] :
1290/// tensor<12xf64>
1291///
1292/// to
1293///
1294/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1295struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1296 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1297
1298 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1299 PatternRewriter &rewriter) const final {
1300 auto collapseOp =
1301 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1302 if (!collapseOp)
1303 return failure();
1304 if (!collapseOp.getSrcType().hasStaticShape())
1305 return failure();
1306
1307 auto sourceSizes = collapseOp.getSrcType().getShape();
1308
1309 SmallVector<Value> indices(extractOp.getIndices().begin(),
1310 extractOp.getIndices().end());
1311 SmallVector<Value> sourceIndices;
1312 for (auto [index, group] :
1313 llvm::zip(indices, collapseOp.getReassociationIndices())) {
1314 assert(!group.empty() && "association indices groups cannot be empty");
1315 auto groupSize = group.size();
1316
1317 if (groupSize == 1) {
1318 sourceIndices.push_back(index);
1319 continue;
1320 }
1321
1322 SmallVector<int64_t> basis =
1323 llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1324 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1325 rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1326 llvm::append_range(sourceIndices, delinearize.getResults());
1327 }
1328 if (collapseOp.getReassociationIndices().empty()) {
1329 auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1330 int64_t srcRank =
1331 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1332 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1333 rewriter, extractOp.getLoc(), zeroAffineMap,
1334 ArrayRef<OpFoldResult>{});
1335 for (int64_t i = 0; i < srcRank; i++) {
1336 sourceIndices.push_back(
1337 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1338 }
1339 }
1340
1341 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1342 extractOp, collapseOp.getSrc(), sourceIndices);
1343 return success();
1344 }
1345};
1346
1347} // namespace
1348
1349void ExtractOp::getAsmResultNames(
1350 function_ref<void(Value, StringRef)> setNameFn) {
1351 setNameFn(getResult(), "extracted");
1352}
1353
1354LogicalResult ExtractOp::verify() {
1355 // Verify the # indices match if we have a ranked type.
1356 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1357 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1358 return emitOpError("incorrect number of indices for extract_element");
1359 return success();
1360}
1361
1362/// If we have an ExtractOp consuming an InsertOp with the same
1363/// indices, we can return the InsertOp's scalar directly.
1364// TODO: This only checks the immediate producer; extend to go up the
1365// insert/extract chain if the slices are disjoint.
1366static Value foldExtractAfterInsert(ExtractOp extractOp) {
1367 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1368
1369 auto isSame = [](Value a, Value b) {
1370 return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1371 };
1372 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1373 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1374 return insertOp.getScalar();
1375
1376 return {};
1377}
1378
1379OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1380 if (Attribute tensor = adaptor.getTensor()) {
1381 // If this is a splat elements attribute, simply return the value.
1382 // All of the elements of a splat attribute are the same.
1383 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1384 return splatTensor.getSplatValue<Attribute>();
1385
1386 // If this is a dense resource elements attribute, return.
1387 if (isa<DenseResourceElementsAttr>(tensor))
1388 return {};
1389 }
1390
1391 // Collect the constant indices into the tensor.
1392 SmallVector<uint64_t, 8> indices;
1393 for (Attribute indice : adaptor.getIndices()) {
1394 if (!indice || !llvm::isa<IntegerAttr>(indice))
1395 return {};
1396 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1397 }
1398
1399 // Fold extract(from_elements(...)).
1400 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1401 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1402 auto rank = tensorType.getRank();
1403 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1404 "rank mismatch");
1405 int flatIndex = 0;
1406 int stride = 1;
1407 for (int i = rank - 1; i >= 0; --i) {
1408 flatIndex += indices[i] * stride;
1409 stride *= tensorType.getDimSize(i);
1410 }
1411 // Prevent out of bounds accesses. This can happen in invalid code that
1412 // will never execute.
1413 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1414 flatIndex < 0)
1415 return {};
1416 return fromElementsOp.getElements()[flatIndex];
1417 }
1418
1419 // If this is an elements attribute, query the value at the given indices.
1420 if (Attribute tensor = adaptor.getTensor()) {
1421 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1422 if (elementsAttr && elementsAttr.isValidIndex(indices))
1423 return elementsAttr.getValues<Attribute>()[indices];
1424 }
1425
1426 if (Value result = foldExtractAfterInsert(*this))
1427 return result;
1428
1429 return {};
1430}
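The from_elements folding above linearizes the constant indices row-major; restated as a standalone helper (illustrative, not upstream API):

```cpp
// Row-major linearization as done in ExtractOp::fold: the last index varies
// fastest, e.g. indices [1, 2] into shape [2, 3] give 1 * 3 + 2 = 5.
static int64_t flattenRowMajor(ArrayRef<uint64_t> indices,
                               ArrayRef<int64_t> shape) {
  int64_t flat = 0;
  int64_t stride = 1;
  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
    flat += static_cast<int64_t>(indices[i]) * stride;
    stride *= shape[i];
  }
  return flat;
}
```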
1431
1432void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1433 MLIRContext *context) {
1434 results.add<ExtractFromTensorCast>(context);
1435}
1436
1437void mlir::tensor::populateFoldCollapseExtractPatterns(
1438 RewritePatternSet &patterns) {
1439 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1440}
1441
1442//===----------------------------------------------------------------------===//
1443// FromElementsOp
1444//===----------------------------------------------------------------------===//
1445
1446void FromElementsOp::getAsmResultNames(
1447 function_ref<void(Value, StringRef)> setNameFn) {
1448 setNameFn(getResult(), "from_elements");
1449}
1450
1451void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1452 ValueRange elements) {
1453 assert(!elements.empty() && "expected at least one element");
1454 Type resultType = RankedTensorType::get(
1455 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1456 build(builder, result, resultType, elements);
1457}
1458
1459OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1460 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1461 return DenseElementsAttr::get(getType(), adaptor.getElements());
1462 return {};
1463}
1464
1465namespace {
1466
1467// Pushes the index_casts that occur before extractions to after the extract.
1468// This minimizes type conversion in some cases and enables the extract
1469// canonicalizer. This changes:
1470//
1471// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1472// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1473//
1474// to the following:
1475//
1476// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1477// %cast = arith.index_cast %extract : i32 to index
1478//
1479// and, together with the extract canonicalizer, often reduces to just the element.
1480//
1481// Consider expanding this to a template and handle all tensor cast
1482// operations.
1483struct ExtractElementFromIndexCast
1484 : public OpRewritePattern<tensor::ExtractOp> {
1485 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1486
1487 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1488 PatternRewriter &rewriter) const final {
1489 Location loc = extract.getLoc();
1490 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1491 if (!indexCast)
1492 return failure();
1493
1494 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1495
1496 auto newExtract = tensor::ExtractOp::create(
1497 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1498
1499 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1500 newExtract);
1501
1502 return success();
1503 }
1504};
1505
1506} // namespace
1507
1508void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1509 MLIRContext *context) {
1510 results.add<ExtractElementFromIndexCast>(context);
1511}
1512
1513//===----------------------------------------------------------------------===//
1514// GatherOp
1515//===----------------------------------------------------------------------===//
1516
1517void GatherOp::getAsmResultNames(
1518 function_ref<void(Value, StringRef)> setNameFn) {
1519 setNameFn(getResult(), "gather");
1520}
1521
1522/// Return the inferred result type for a gatherOp where:
1523/// - sourceType is the type of the source tensor gathered from
1524/// - indicesType is the type of the indices used to gather
1525/// - gatherDims are the dims along which the gather occurs.
1526/// Return a full rank or ranked-reduced variant of the type depending on
1527/// the value of rankReduced.
1528///
1529/// The leading dimensions of the index tensor give the result tensor its
1530/// leading dimensions.
1531/// The trailing dimensions of the result tensor are obtained from the source
1532/// tensor by setting the dimensions specified in gather_dims to `1` (if
1533/// rankReduced is false), or skipping them (otherwise).
1534RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1535 RankedTensorType indicesType,
1536 ArrayRef<int64_t> gatherDims,
1537 bool rankReduced) {
1538 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1539 resultShape.reserve(resultShape.size() + sourceType.getRank());
1540 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1541 if (llvm::binary_search(gatherDims, idx)) {
1542 if (!rankReduced)
1543 resultShape.push_back(1);
1544 continue;
1545 }
1546 resultShape.push_back(sourceType.getDimSize(idx));
1547 }
1548 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1549}
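Illustrative inference example for the rules spelled out in the comment above:

```cpp
// Gathering along dim 1 of a 4x5x6 source with a 2x3x1 index tensor: the
// leading 2x3 comes from the indices; dim 1 becomes a unit dim, or is dropped
// entirely in the rank-reduced variant.
static void exampleGatherInference(MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  auto source = RankedTensorType::get({4, 5, 6}, f32);
  auto indices = RankedTensorType::get({2, 3, 1}, IndexType::get(ctx));
  auto full = GatherOp::inferResultType(source, indices, /*gatherDims=*/{1},
                                        /*rankReduced=*/false);
  auto reduced = GatherOp::inferResultType(source, indices, /*gatherDims=*/{1},
                                           /*rankReduced=*/true);
  assert(full == RankedTensorType::get({2, 3, 4, 1, 6}, f32));
  assert(reduced == RankedTensorType::get({2, 3, 4, 6}, f32));
}
```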
1550
1551static LogicalResult
1552verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1553 ArrayRef<int64_t> indices, int64_t rank,
1554 StringRef gatherOrScatter, StringRef sourceOrDest) {
1555 if (dims.empty())
1556 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1557
1558 int64_t numGatherDims = dims.size();
1559 if (numGatherDims > rank)
1560 return op->emitOpError(gatherOrScatter)
1561 << "_dims overflow " << sourceOrDest << " rank";
1562 if (indices.empty() || indices.back() != numGatherDims)
1563 return op->emitOpError(gatherOrScatter)
1564 << "_dims length must match the size of last dimension of indices";
1565 for (int64_t val : dims) {
1566 if (val < 0)
1567 return op->emitOpError(gatherOrScatter)
1568 << "_dims value must be non-negative";
1569 if (val >= rank)
1570 return op->emitOpError(gatherOrScatter)
1571 << "_dims value must be smaller than " << sourceOrDest << " rank";
1572 }
1573 for (int64_t i = 1; i < numGatherDims; ++i) {
1574 if (dims[i - 1] >= dims[i])
1575 return op->emitOpError(gatherOrScatter)
1576 << "_dims values must be strictly increasing";
1577 }
1578 return success();
1579}
1580
1581LogicalResult GatherOp::verify() {
1582 int64_t sourceRank = getSourceType().getRank();
1583 ArrayRef<int64_t> gatherDims = getGatherDims();
1584 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1585 getIndicesType().getShape(), sourceRank,
1586 "gather", "source")))
1587 return failure();
1588
1589 RankedTensorType expectedResultType = GatherOp::inferResultType(
1590 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1591 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1592 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1593 if (getResultType() != expectedResultType &&
1594 getResultType() != expectedRankReducedResultType) {
1595 return emitOpError("result type "
1596 "mismatch: "
1597 "expected ")
1598 << expectedResultType << " or its rank-reduced variant "
1599 << expectedRankReducedResultType << " (got: " << getResultType()
1600 << ")";
1601 }
1602
1603 return success();
1604}
1605
1606OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1607 if (OpFoldResult reshapedSource = reshapeConstantSource(
1608 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1609 getResult().getType()))
1610 return reshapedSource;
1611 return {};
1612}
1613
1614//===----------------------------------------------------------------------===//
1615// InsertOp
1616//===----------------------------------------------------------------------===//
1617
1618void InsertOp::getAsmResultNames(
1619 function_ref<void(Value, StringRef)> setNameFn) {
1620 setNameFn(getResult(), "inserted");
1621}
1622
1623LogicalResult InsertOp::verify() {
1624 // Verify the # indices match if we have a ranked type.
1625 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1626 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1627 return emitOpError("incorrect number of indices");
1628 return success();
1629}
1630
1631OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1632 Attribute scalar = adaptor.getScalar();
1633 Attribute dest = adaptor.getDest();
1634 if (scalar && dest)
1635 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1636 if (scalar == splatDest.getSplatValue<Attribute>())
1637 return dest;
1638 return {};
1639}
1640
1641//===----------------------------------------------------------------------===//
1642// GenerateOp
1643//===----------------------------------------------------------------------===//
1644
1645void GenerateOp::getAsmResultNames(
1646 function_ref<void(Value, StringRef)> setNameFn) {
1647 setNameFn(getResult(), "generated");
1648}
1649
1650LogicalResult GenerateOp::reifyResultShapes(
1651 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1652 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1653 int idx = 0;
1654 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1655 if (getType().isDynamicDim(dim)) {
1656 reifiedReturnShapes[0][dim] = getOperand(idx++);
1657 } else {
1658 reifiedReturnShapes[0][dim] =
1659 builder.getIndexAttr(getType().getDimSize(dim));
1660 }
1661 }
1662 return success();
1663}
1664
1665LogicalResult GenerateOp::verify() {
1666 // Ensure that the tensor type has as many dynamic dimensions as are
1667 // specified by the operands.
1668 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1669 if (getNumOperands() != resultType.getNumDynamicDims())
1670 return emitError("must have as many index operands as dynamic extents "
1671 "in the result type");
1672 return success();
1673}
1674
1675LogicalResult GenerateOp::verifyRegions() {
1676 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1677 // Ensure that region arguments span the index space.
1678 if (!llvm::all_of(getBody().getArgumentTypes(),
1679 [](Type ty) { return ty.isIndex(); }))
1680 return emitError("all body arguments must be index");
1681 if (getBody().getNumArguments() != resultTy.getRank())
1682 return emitError("must have one body argument per input dimension");
1683
1684 // Ensure that the region yields an element of the right type.
1685 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1686
1687 if (yieldOp.getValue().getType() != resultTy.getElementType())
1688 return emitOpError(
1689 "body must be terminated with a `yield` operation of the tensor "
1690 "element type");
1691
1692 return success();
1693}
1694
1695void GenerateOp::build(
1696 OpBuilder &b, OperationState &result, Type resultTy,
1697 ValueRange dynamicExtents,
1698 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1699 build(b, result, resultTy, dynamicExtents);
1700
1701 // Build and populate body.
1702 OpBuilder::InsertionGuard guard(b);
1703 Region *bodyRegion = result.regions.front().get();
1704 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1705 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1706 SmallVector<Location, 2> argumentLocs(rank, result.location);
1707 Block *bodyBlock =
1708 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1709 bodyBuilder(b, result.location, bodyBlock->getArguments());
1710}
1711
1712namespace {
1713
1714/// Canonicalizes tensor.generate operations with constant extent operands
1715/// into an equivalent operation whose result type carries those extents
1716/// statically. We also insert a tensor.cast back to the original type so the
1717/// resulting IR stays well-typed (see the example after the pattern).
1718struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1719 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1720
1721 LogicalResult matchAndRewrite(GenerateOp generateOp,
1722 PatternRewriter &rewriter) const final {
1723 SmallVector<Value> foldedDynamicSizes;
1724 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1725 generateOp.getType(), generateOp.getDynamicExtents(),
1726 foldedDynamicSizes);
1727
1728 // Stop here if no dynamic size was promoted to static.
1729 if (foldedTensorType == generateOp.getType())
1730 return failure();
1731
1732 auto loc = generateOp.getLoc();
1733 auto newOp =
1734 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1735 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1736 newOp.getBody().begin());
1737 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1738 generateOp.getType(), newOp);
1739 return success();
1740 }
1741};
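// For example (illustrative):
//
// ```mlir
// %c8 = arith.constant 8 : index
// %0 = tensor.generate %c8 {
// ^bb0(%i: index):
//   ...
//   yield %elem : f32
// } : tensor<?xf32>
// ```
//
// becomes a tensor.generate with result type tensor<8xf32>, followed by a
// tensor.cast from tensor<8xf32> back to tensor<?xf32>.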
1742
1743/// Canonicalizes the pattern of the form
1744///
1745/// %tensor = tensor.generate %x {
1746/// ^bb0(%arg0: index):
1747/// <computation>
1748/// yield %1 : index
1749/// } : tensor<?xindex>
1750/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1751///
1752/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1753/// tensor.generate operation has no side-effects.
1754struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1755 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1756
1757 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1758 PatternRewriter &rewriter) const final {
1759 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1760 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1761 return failure();
1762
1763 IRMapping mapping;
1764 Block *body = &tensorFromElements.getBody().front();
1765 mapping.map(body->getArguments(), extract.getIndices());
1766 for (auto &op : body->without_terminator())
1767 rewriter.clone(op, mapping);
1768
1769 auto yield = cast<YieldOp>(body->getTerminator());
1770
1771 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1772 return success();
1773 }
1774};
1775
1776} // namespace
1777
1778void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1779 MLIRContext *context) {
1780 // TODO: Move extract pattern to tensor::ExtractOp.
1781 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1782}
1783
1784//===----------------------------------------------------------------------===//
1785// RankOp
1786//===----------------------------------------------------------------------===//
1787
1788void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1789 setNameFn(getResult(), "rank");
1790}
1791
1792OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1793 // Constant fold rank when the rank of the operand is known.
1794 auto type = getOperand().getType();
1795 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1796 if (shapedType && shapedType.hasRank())
1797 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1798 return IntegerAttr();
1799}
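// For example (illustrative):
//
// ```mlir
// %rank = tensor.rank %t : tensor<4x?xf32>
// ```
//
// folds to a constant 2 of type index, since the rank is static even though
// one dimension is dynamic.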
1800
1801//===----------------------------------------------------------------------===//
1802// ReshapeOp
1803//===----------------------------------------------------------------------===//
1804
1805void ReshapeOp::getAsmResultNames(
1806 function_ref<void(Value, StringRef)> setNameFn) {
1807 setNameFn(getResult(), "reshape");
1808}
1809
1810static int64_t getNumElements(ShapedType type) {
1811 int64_t numElements = 1;
1812 for (auto dim : type.getShape())
1813 numElements *= dim;
1814 return numElements;
1815}
1816
1817LogicalResult ReshapeOp::verify() {
1818 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1819 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1820
1821 if (operandType.getElementType() != resultType.getElementType())
1822 return emitOpError("element types of source and destination tensor "
1823 "types should be the same");
1824
1825 int64_t shapeSize =
1826 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1827 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1828 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1829
1830 if (resultRankedType) {
1831 if (operandRankedType && resultRankedType.hasStaticShape() &&
1832 operandRankedType.hasStaticShape()) {
1833 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1834 return emitOpError("source and destination tensor should have the "
1835 "same number of elements");
1836 }
1837 if (ShapedType::isDynamic(shapeSize))
1838 return emitOpError("cannot use shape operand with dynamic length to "
1839 "reshape to statically-ranked tensor type");
1840 if (shapeSize != resultRankedType.getRank())
1841 return emitOpError(
1842 "length of shape operand differs from the result's tensor rank");
1843 }
1844 return success();
1845}
1846
1847OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1848 if (OpFoldResult reshapedSource = reshapeConstantSource(
1849 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1850 getResult().getType()))
1851 return reshapedSource;
1852
1853 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1854 // producer's input instead as the original tensor to reshape. This could
1855 // render such producer dead code.
1856 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1857 getSourceMutable().assign(reshapeOpProducer.getSource());
1858 return getResult();
1859 }
1860
1861 auto source = getSource();
1862 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1863 auto resultTy = dyn_cast<RankedTensorType>(getType());
1864 if (!sourceTy || !resultTy || sourceTy != resultTy)
1865 return {};
1866
1867 // If the source and result are both 0D or 1D tensors and have the same type,
1868 // the reshape has no effect, even if the tensor is dynamically shaped.
1869 if (sourceTy.getRank() <= 1)
1870 return source;
1871
1872 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1873 auto elements = fromElements.getElements();
1874 bool dynamicNoop =
1875 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1876 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1877 auto element = elements[id];
1878
1879 if (auto cst = getConstantIntValue(element)) {
1880 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1881 continue;
1882 }
1883
1884 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1885 dynamicNoop &= dimOp.getSource() == source;
1886
1887 auto cst = getConstantIntValue(dimOp.getIndex());
1888 dynamicNoop &=
1889 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1890 continue;
1891 }
1892
1893 dynamicNoop = false;
1894 break;
1895 }
1896
1897 if (dynamicNoop)
1898 return source;
1899 }
1900
1901 return {};
1902}
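// As an illustrative example of the dynamic no-op case handled above
// (hypothetical shapes):
//
// ```mlir
// %c0 = arith.constant 0 : index
// %c8 = arith.constant 8 : index
// %d0 = tensor.dim %src, %c0 : tensor<?x8xf32>
// %shape = tensor.from_elements %d0, %c8 : tensor<2xindex>
// %r = tensor.reshape %src(%shape)
//     : (tensor<?x8xf32>, tensor<2xindex>) -> tensor<?x8xf32>
// ```
//
// Every shape element either equals the corresponding static size or is the
// matching tensor.dim of %src, so %r folds to %src.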
1903
1904//===----------------------------------------------------------------------===//
1905// Reassociative reshape ops
1906//===----------------------------------------------------------------------===//
1907
1908void CollapseShapeOp::getAsmResultNames(
1909 function_ref<void(Value, StringRef)> setNameFn) {
1910 setNameFn(getResult(), "collapsed");
1911}
1912
1913void ExpandShapeOp::getAsmResultNames(
1914 function_ref<void(Value, StringRef)> setNameFn) {
1915 setNameFn(getResult(), "expanded");
1916}
1917
1918int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1919 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1920 "invalid resultDim");
1921 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1922 if (llvm::is_contained(it.value(), resultDim))
1923 return it.index();
1924 llvm_unreachable("could not find reassociation group");
1925}
1926
1927FailureOr<SmallVector<OpFoldResult>>
1928ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1929 RankedTensorType expandedType,
1930 ArrayRef<ReassociationIndices> reassociation,
1931 ArrayRef<OpFoldResult> inputShape) {
1932 std::optional<SmallVector<OpFoldResult>> outputShape =
1933 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1934 inputShape);
1935 if (!outputShape)
1936 return failure();
1937 return *outputShape;
1938}
1939
1940SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1941 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1942}
1943
1944void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1945 Type resultType, Value src,
1946 ArrayRef<ReassociationIndices> reassociation,
1947 ArrayRef<OpFoldResult> outputShape) {
1948 auto [staticOutputShape, dynamicOutputShape] =
1949 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1950 build(builder, result, cast<RankedTensorType>(resultType), src,
1951 getReassociationIndicesAttribute(builder, reassociation),
1952 dynamicOutputShape, staticOutputShape);
1953}
1954
1955void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1956 Type resultType, Value src,
1957 ArrayRef<ReassociationIndices> reassociation) {
1958 SmallVector<OpFoldResult> inputShape =
1959 getMixedSizes(builder, result.location, src);
1960 auto tensorResultTy = cast<RankedTensorType>(resultType);
1961 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1962 builder, result.location, tensorResultTy, reassociation, inputShape);
1963 SmallVector<OpFoldResult> outputShapeOrEmpty;
1964 if (succeeded(outputShape)) {
1965 outputShapeOrEmpty = *outputShape;
1966 }
1967 build(builder, result, tensorResultTy, src, reassociation,
1968 outputShapeOrEmpty);
1969}
1970
1971SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1972 return getSymbolLessAffineMaps(getReassociationExprs());
1973}
1974SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1975 return convertReassociationIndicesToExprs(getContext(),
1976 getReassociationIndices());
1977}
1978
1979SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1980 return getSymbolLessAffineMaps(getReassociationExprs());
1981}
1982SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1983 return convertReassociationIndicesToExprs(getContext(),
1984 getReassociationIndices());
1985}
1986
1987RankedTensorType CollapseShapeOp::inferCollapsedType(
1988 RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1989 return inferCollapsedType(
1990 type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1991 type.getContext(), reassociation)));
1992}
1993
1994/// Compute the RankedTensorType obtained by applying `reassociation` to
1995/// `type`.
1996RankedTensorType
1997CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1998 ArrayRef<AffineMap> reassociation) {
1999 auto shape = type.getShape();
2000 SmallVector<int64_t, 4> newShape;
2001 newShape.reserve(reassociation.size());
2002
2003 // Use the fact that reassociation is valid to simplify the logic: only use
2004 // each map's rank.
2005 assert(isReassociationValid(reassociation) && "invalid reassociation");
2006 unsigned currentDim = 0;
2007 for (AffineMap m : reassociation) {
2008 unsigned dim = m.getNumResults();
2009 auto band = shape.slice(currentDim, dim);
2010 int64_t size = 1;
2011 if (llvm::is_contained(band, ShapedType::kDynamic))
2012 size = ShapedType::kDynamic;
2013 else
2014 for (unsigned d = 0; d < dim; ++d)
2015 size *= shape[currentDim + d];
2016 newShape.push_back(size);
2017 currentDim += dim;
2018 }
2019
2020 return RankedTensorType::get(newShape, type.getElementType());
2021}
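// For example (illustrative), with reassociation [[0, 1], [2, 3]]:
//
// ```mlir
// %0 = tensor.collapse_shape %t [[0, 1], [2, 3]]
//     : tensor<2x3x4x?xf32> into tensor<6x?xf32>
// ```
//
// The first group collapses to 2 * 3 = 6; the second group contains a
// dynamic dimension, so it collapses to a dynamic size.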
2022
2023void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2024 ArrayRef<ReassociationIndices> reassociation,
2025 ArrayRef<NamedAttribute> attrs) {
2026 auto resultType = inferCollapsedType(
2027 llvm::cast<RankedTensorType>(src.getType()),
2028 getSymbolLessAffineMaps(
2029 convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2030 result.addAttribute(getReassociationAttrStrName(),
2031 getReassociationIndicesAttribute(b, reassociation));
2032 build(b, result, resultType, src, attrs);
2033}
2034
2035template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2036 TensorReshapeOp, ExpandShapeOp>::value>
2037static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2038 RankedTensorType expandedType,
2039 RankedTensorType collapsedType) {
2040 if (failed(
2041 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2042 return failure();
2043
2044 auto maps = op.getReassociationMaps();
2045 RankedTensorType expectedType =
2046 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2047 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2048 return op.emitOpError("expected collapsed type to be ")
2049 << expectedType << ", but got " << collapsedType;
2050 return success();
2051}
2052
2053LogicalResult ExpandShapeOp::verify() {
2054 auto srcType = getSrcType();
2055 auto resultType = getResultType();
2056
2057 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2058 return emitOpError("expected number of static shape dims to be equal to "
2059 "the output rank (")
2060 << resultType.getRank() << ") but found "
2061 << getStaticOutputShape().size() << " inputs instead";
2062
2063 if ((int64_t)getOutputShape().size() !=
2064 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2065 return emitOpError("mismatch in dynamic dims in output_shape and "
2066 "static_output_shape: static_output_shape has ")
2067 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2068 << " dynamic dims while output_shape has " << getOutputShape().size()
2069 << " values";
2070
2071 return verifyTensorReshapeOp(*this, resultType, srcType);
2072}
2073
2074LogicalResult CollapseShapeOp::verify() {
2075 return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2076}
2077
2078namespace {
2079/// Reshape of a splat constant can be replaced with a constant of the result
2080/// type.
2081template <typename TensorReshapeOp>
2082struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2083 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2084 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2085 PatternRewriter &rewriter) const override {
2086 DenseElementsAttr attr;
2087 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2088 return failure();
2089 if (!attr || !attr.isSplat())
2090 return failure();
2091 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2092 reshapeOp.getResultType(), attr.getRawData());
2093 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2094 return success();
2095 }
2096};
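// For example (illustrative):
//
// ```mlir
// %cst = arith.constant dense<1.0> : tensor<2x4xf32>
// %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
// ```
//
// rewrites to a single arith.constant dense<1.0> : tensor<8xf32>.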
2097
2098// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2099template <typename TensorReshapeOp>
2100class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2101public:
2102 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2103
2104 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2105 PatternRewriter &rewriter) const override {
2106 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2107 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2108 return failure();
2109
2110 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2111 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2112 return success();
2113 }
2114};
2115
2116/// Reshape of a FromElements can be replaced with a FromElements of the
2117/// result type
2118template <typename TensorReshapeOp>
2119struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2120 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2121 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2122 PatternRewriter &rewriter) const override {
2123 auto fromElements =
2124 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2125 if (!fromElements)
2126 return failure();
2127
2128 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2129
2130 if (!shapedTy.hasStaticShape())
2131 return failure();
2132
2133 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2134 fromElements.getElements());
2135 return success();
2136 }
2137};
2138
2139// Fold CastOp into CollapseShapeOp when adding static information.
2140struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2141 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2142
2143 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2144 PatternRewriter &rewriter) const override {
2145 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2146 if (!tensor::canFoldIntoConsumerOp(castOp))
2147 return failure();
2148
2149 RankedTensorType srcType =
2150 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2151 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2152 srcType, collapseShapeOp.getReassociationMaps());
2153
2154 if (newResultType == collapseShapeOp.getResultType()) {
2155 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2156 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2157 });
2158 } else {
2159 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2160 newResultType, castOp.getSource(),
2161 collapseShapeOp.getReassociation());
2162 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2163 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2164 }
2165 return success();
2166 }
2167};
2168
2169/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2170/// matching constant output_shape operands of the expand. This makes the
2171/// `tensor.expand_shape` more static and creates a consumer cast that can be
2172/// propagated further.
2173struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2174 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2175
2176 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2177 PatternRewriter &rewriter) const override {
2178 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2179 if (!canFoldIntoConsumerOp(castOp))
2180 return failure();
2181
2182 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2183 SmallVector<ReassociationIndices, 4> reassoc =
2184 expandOp.getReassociationIndices();
2185
2186 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2187 SmallVector<Value> dynamicOutputShape;
2188 auto outputIt = expandOp.getOutputShape().begin();
2189
2190 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2191 for (uint64_t outDim : innerReassoc) {
2192 if (ShapedType::isStatic(newOutputShape[outDim]))
2193 continue;
2194
2195 // If the cast's src type is dynamic, don't infer any of the
2196 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2197 // least one of the expanded dimensions to be dynamic if the input is
2198 // dynamic.
2199 Value val = *outputIt;
2200 ++outputIt;
2201 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2202 dynamicOutputShape.push_back(val);
2203 continue;
2204 }
2205
2206 APInt cst;
2207 if (matchPattern(val, m_ConstantInt(&cst))) {
2208 newOutputShape[outDim] = cst.getSExtValue();
2209 } else {
2210 dynamicOutputShape.push_back(val);
2211 }
2212 }
2213 }
2214
2215 // Couldn't match any values, nothing to change
2216 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2217 return failure();
2218
2219 // Calculate the input shape from the output
2220 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2221 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2222 for (auto outDim : reassoc[inDim]) {
2223 auto ofr = newOutputShape[outDim];
2224 if (ShapedType::isDynamic(ofr)) {
2225 newInputShape[inDim] = ShapedType::kDynamic;
2226 break;
2227 }
2228 newInputShape[inDim] *= ofr;
2229 }
2230 }
2231
2232 SmallVector<OpFoldResult> outputOfr =
2233 getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2234 auto inputType = RankedTensorType::get(
2235 newInputShape, expandOp.getSrcType().getElementType());
2236 auto outputType = RankedTensorType::get(
2237 newOutputShape, expandOp.getSrcType().getElementType());
2238 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2239 expandOp.getSrc());
2240 auto newExpand = ExpandShapeOp::create(
2241 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2242 expandOp.getReassociationIndices(), outputOfr);
2243 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2244 newExpand.getResult());
2245 return success();
2246 }
2247};
2248} // namespace
2249
2250void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2251 MLIRContext *context) {
2252 results.add<
2253 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2254 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2255 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2256 FoldReshapeWithSplat<ExpandShapeOp>,
2257 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2258}
2259
2260void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2261 MLIRContext *context) {
2262 results.add<
2263 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2264 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2265 tensor::DimOp, RankedTensorType>,
2266 FoldReshapeWithConstant<CollapseShapeOp>,
2267 FoldReshapeWithSplat<CollapseShapeOp>,
2268 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2269 context);
2270}
2271
2272OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2273 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2274 adaptor.getOperands());
2275}
2276
2277OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2278 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2279 adaptor.getOperands());
2280}
2281
2282//===----------------------------------------------------------------------===//
2283// ExtractSliceOp
2284//===----------------------------------------------------------------------===//
2285
2286void ExtractSliceOp::getAsmResultNames(
2287 function_ref<void(Value, StringRef)> setNameFn) {
2288 setNameFn(getResult(), "extracted_slice");
2289}
2290
2291/// An extract_slice result type can be inferred, when it is not
2292/// rank-reduced, from the source type and the static representation of
2293/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2294RankedTensorType
2295ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2296 ArrayRef<int64_t> staticSizes) {
2297 // An extract_slice op may specify only a leading subset of offset/sizes/
2298 // strides, in which case we complete with offset=0, sizes from the source
2299 // tensor type, and strides=1.
2300 assert(static_cast<int64_t>(staticSizes.size()) ==
2301 sourceTensorType.getRank() &&
2302 "unexpected staticSizes not equal to rank of source");
2303 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2304 sourceTensorType.getEncoding());
2305}
2306
2307// TODO: This uses neither offsets nor strides!
2308RankedTensorType
2309ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2310 ArrayRef<OpFoldResult> sizes) {
2311 SmallVector<int64_t> staticSizes;
2312 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2313
2314 assert(static_cast<int64_t>(staticSizes.size()) ==
2315 sourceTensorType.getRank() &&
2316 "unexpected staticSizes not equal to rank of source");
2317 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2318 sourceTensorType.getEncoding());
2319}
2320
2321/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2322/// number of sizes), drop as many size-1 dimensions as needed to produce an inferred
2323/// type with the desired rank.
2324///
2325/// Note that there may be multiple ways to compute this rank-reduced type:
2326/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2327///
2328/// To disambiguate, this function always drops the earliest size-1 dimensions first.
2329RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2330 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2331 ArrayRef<int64_t> sizes) {
2332 // Type inferred in the absence of rank-reducing behavior.
2333 auto inferredType = llvm::cast<RankedTensorType>(
2334 inferResultType(sourceRankedTensorType, sizes));
2335 int rankDiff = inferredType.getRank() - desiredResultRank;
2336 if (rankDiff > 0) {
2337 auto shape = inferredType.getShape();
2338 llvm::SmallBitVector dimsToProject =
2339 getPositionsOfShapeOne(rankDiff, shape);
2340 SmallVector<int64_t> projectedShape;
2341 // Best effort rank-reducing: drop 1s in order.
2342 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2343 if (!dimsToProject.test(pos))
2344 projectedShape.push_back(shape[pos]);
2345 inferredType =
2346 RankedTensorType::get(projectedShape, inferredType.getElementType());
2347 }
2348 return inferredType;
2349}
2350
2351RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2352 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2353 ArrayRef<OpFoldResult> sizes) {
2354 SmallVector<int64_t> staticSizes;
2355 SmallVector<Value> dynamicSizes;
2356 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2357 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2358 desiredResultRank, sourceRankedTensorType, staticSizes);
2359}
2360
2361/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2362/// result type. If the type passed is nullptr, it is inferred.
2363void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2364 RankedTensorType resultType, Value source,
2365 ArrayRef<OpFoldResult> offsets,
2366 ArrayRef<OpFoldResult> sizes,
2367 ArrayRef<OpFoldResult> strides,
2368 ArrayRef<NamedAttribute> attrs) {
2369 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2370 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2371 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2372 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2373 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2374 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2375 // Structuring implementation this way avoids duplication between builders.
2376 if (!resultType) {
2377 resultType = llvm::cast<RankedTensorType>(
2378 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2379 }
2380 result.addAttributes(attrs);
2381 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2382 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2383 b.getDenseI64ArrayAttr(staticSizes),
2384 b.getDenseI64ArrayAttr(staticStrides));
2385}
2386
2387/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2388/// result type.
2389void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2390 ArrayRef<OpFoldResult> offsets,
2391 ArrayRef<OpFoldResult> sizes,
2392 ArrayRef<OpFoldResult> strides,
2393 ArrayRef<NamedAttribute> attrs) {
2394 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2395}
2396
2397/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2398/// a Range vector.
2399void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2400 ArrayRef<Range> ranges,
2401 ArrayRef<NamedAttribute> attrs) {
2402 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2403 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2404}
2405
2406/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2407/// the type passed is nullptr, it is inferred.
2408void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2409 RankedTensorType resultType, Value source,
2410 ValueRange offsets, ValueRange sizes,
2411 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2412 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2413 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2414 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2415 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2416 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2417 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2418 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2419}
2420
2421/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2422void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2423 ValueRange offsets, ValueRange sizes,
2424 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2425 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2426}
2427
2428static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2429 Operation *op,
2430 RankedTensorType expectedType) {
2431 switch (result) {
2432 case SliceVerificationResult::Success:
2433 return success();
2434 case SliceVerificationResult::RankTooLarge:
2435 return op->emitError("expected rank to be smaller or equal to ")
2436 << "the other rank. ";
2437 case SliceVerificationResult::SizeMismatch:
2438 return op->emitError("expected type to be ")
2439 << expectedType << " or a rank-reduced version. (size mismatch) ";
2440 case SliceVerificationResult::ElemTypeMismatch:
2441 return op->emitError("expected element type to be ")
2442 << expectedType.getElementType();
2443 default:
2444 llvm_unreachable("unexpected extract_slice op verification result");
2445 }
2446}
2447
2448/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2449/// result type, offsets set to 0 and strides set to 1.
2450void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2451 RankedTensorType resultType, Value source,
2452 ArrayRef<OpFoldResult> sizes,
2453 ArrayRef<NamedAttribute> attrs) {
2454 Attribute zeroIdxAttr = b.getIndexAttr(0);
2455 Attribute oneIdxAttr = b.getIndexAttr(1);
2456 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2457 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2458 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2459}
2460
2461/// Verifier for ExtractSliceOp.
2462LogicalResult ExtractSliceOp::verify() {
2463 RankedTensorType sourceType = getSourceType();
2464
2465 // Verify result type against inferred type.
2466 RankedTensorType expectedType =
2467 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2468 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2469 if (result != SliceVerificationResult::Success)
2470 return produceSliceErrorMsg(result, *this, expectedType);
2471
2472 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2473 // to the source tensor.
2474 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2475 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2476 getStaticStrides(), /*generateErrorMessage=*/true);
2477 if (!boundsResult.isValid)
2478 return getOperation()->emitError(boundsResult.errorMessage);
2479
2480 return success();
2481}
2482
2483llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2484 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2485}
2486
2487FailureOr<Value>
2488ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2489 ArrayRef<int64_t> desiredShape) {
2490 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2491 assert(sourceTensorType && "not a ranked tensor type");
2492 auto sourceShape = sourceTensorType.getShape();
2493 if (sourceShape.equals(desiredShape))
2494 return value;
2495 auto maybeRankReductionMask =
2496 mlir::computeRankReductionMask(sourceShape, desiredShape);
2497 if (!maybeRankReductionMask)
2498 return failure();
2499 return createCanonicalRankReducingExtractSliceOp(
2500 b, loc, value,
2501 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2502}
2503
2504LogicalResult ExtractSliceOp::reifyResultShapes(
2505 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2506 reifiedReturnShapes.resize(1);
2507 reifiedReturnShapes[0].reserve(getType().getRank());
2508 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2509 llvm::SmallBitVector droppedDims = getDroppedDims();
2510 for (const auto &size : enumerate(mixedSizes)) {
2511 if (droppedDims.test(size.index()))
2512 continue;
2513 reifiedReturnShapes[0].push_back(size.value());
2514 }
2515 return success();
2516}
2517
2518namespace {
2519/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2520/// This essentially pushes the tensor.cast past its consuming slice when
2521/// `canFoldIntoConsumerOp` is true.
2522///
2523/// Example:
2524/// ```
2525/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2526/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2527/// tensor<3x4xf32>
2528/// ```
2529/// is rewritten into:
2530/// ```
2531/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2532/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2533/// ```
2534class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2535public:
2536 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2537
2538 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2539 PatternRewriter &rewriter) const override {
2540 // Any constant operand, just return to let the constant folder kick in.
2541 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2542 return matchPattern(operand, matchConstantIndex());
2543 }))
2544 return failure();
2545
2546 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2547 if (!castOp)
2548 return failure();
2549
2550 if (!canFoldIntoConsumerOp(castOp))
2551 return failure();
2552
2553 // Pattern does not apply if the produced op would not verify.
2554 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2555 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2556 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2557 sliceOp.getStaticStrides());
2558 if (!sliceResult.isValid)
2559 return failure();
2560
2561 // Create folded extract.
2562 Location loc = sliceOp.getLoc();
2563 Value newResult = ExtractSliceOp::create(
2564 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2565 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2566 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2567 sliceOp.getStaticStrides());
2568 rewriter.replaceOp(sliceOp, newResult);
2569 return success();
2570 }
2571};
2572
2573/// Slice elements from `values` into `outValues`. For each dimension, `counts`
2574/// holds the number of flattened elements spanned by one step along that dimension.
2575/// The output values can be used to construct a DenseElementsAttr.
2576template <typename IterTy, typename ElemTy>
2577static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2578 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2579 ArrayRef<int64_t> strides,
2580 llvm::SmallVectorImpl<ElemTy> *outValues) {
2581 assert(offsets.size() == sizes.size());
2582 assert(offsets.size() == strides.size());
2583 if (offsets.empty())
2584 return;
2585
2586 int64_t offset = offsets.front();
2587 int64_t size = sizes.front();
2588 int64_t stride = strides.front();
2589 if (offsets.size() == 1) {
2590 for (int64_t i = 0; i < size; ++i, offset += stride)
2591 outValues->push_back(*(values + offset));
2592
2593 return;
2594 }
2595
2596 for (int64_t i = 0; i < size; ++i, offset += stride) {
2597 auto begin = values + offset * counts.front();
2598 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2599 offsets.drop_front(), sizes.drop_front(),
2600 strides.drop_front(), outValues);
2601 }
2602}
2603
2604/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2605/// folded operation might introduce more constant data; users can control
2606/// this trade-off through the control function.
2607class ConstantOpExtractSliceFolder final
2608 : public OpRewritePattern<ExtractSliceOp> {
2609public:
2610 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2611
2612 ConstantOpExtractSliceFolder(MLIRContext *context,
2613 ControlConstantExtractSliceFusionFn controlFn)
2614 : OpRewritePattern<ExtractSliceOp>(context),
2615 controlFn(std::move(controlFn)) {}
2616
2617 LogicalResult matchAndRewrite(ExtractSliceOp op,
2618 PatternRewriter &rewriter) const override {
2619 DenseElementsAttr attr;
2620 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2621 return failure();
2622
2623 // A constant splat is handled by fold().
2624 if (attr.isSplat())
2625 return failure();
2626
2627 // Dynamic result shape is not supported.
2628 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2629 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2630 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2631 return failure();
2632
2633 // Customized control over the folding.
2634 if (!controlFn(op))
2635 return failure();
2636
2637 int64_t count = sourceType.getNumElements();
2638 if (count == 0)
2639 return failure();
2640
2641 // Check if there are any dynamic parts, which are not supported.
2642 auto offsets = op.getStaticOffsets();
2643 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2644 return failure();
2645 auto sizes = op.getStaticSizes();
2646 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2647 return failure();
2648 auto strides = op.getStaticStrides();
2649 if (llvm::is_contained(strides, ShapedType::kDynamic))
2650 return failure();
2651
2652 // Compute the stride for each dimension.
2653 SmallVector<int64_t> counts;
2654 ArrayRef<int64_t> shape = sourceType.getShape();
2655 counts.reserve(shape.size());
2656 for (int64_t v : shape) {
2657 count = count / v;
2658 counts.push_back(count);
2659 }
2660
2661 // New attribute constructed by the sliced values.
2662 DenseElementsAttr newAttr;
2663
2664 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2665 SmallVector<APInt> outValues;
2666 outValues.reserve(sourceType.getNumElements());
2667 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2668 elems.begin(), counts, offsets, sizes, strides, &outValues);
2669 newAttr = DenseElementsAttr::get(resultType, outValues);
2670 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2671 SmallVector<APFloat> outValues;
2672 outValues.reserve(sourceType.getNumElements());
2673 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2674 elems.begin(), counts, offsets, sizes, strides, &outValues);
2675 newAttr = DenseElementsAttr::get(resultType, outValues);
2676 }
2677
2678 if (newAttr) {
2679 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2680 return success();
2681 }
2682
2683 return failure();
2684 }
2685
2686private:
2687 /// This additionally controls whether the fold happens or not. Users can
2688 /// impose their heuristics in the function.
2689 ControlConstantExtractSliceFusionFn controlFn;
2690};
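// As an illustrative example of the folder above (hypothetical values):
//
// ```mlir
// %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
// %0 = tensor.extract_slice %cst[0, 1] [2, 2] [1, 1]
//     : tensor<2x3xi32> to tensor<2x2xi32>
// ```
//
// rewrites to %0 = arith.constant dense<[[1, 2], [4, 5]]> : tensor<2x2xi32>,
// provided the control function allows the fold.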
2691
2692} // namespace
2693
2694void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2695 RewritePatternSet &patterns,
2696 const ControlConstantExtractSliceFusionFn &controlFn) {
2697 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2698}
2699
2700/// Return the canonical type of the result of an extract_slice op.
2701struct SliceReturnTypeCanonicalizer {
2702 RankedTensorType operator()(ExtractSliceOp op,
2703 ArrayRef<OpFoldResult> mixedOffsets,
2704 ArrayRef<OpFoldResult> mixedSizes,
2705 ArrayRef<OpFoldResult> mixedStrides) {
2706 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2707 op.getType().getRank(), op.getSourceType(), mixedSizes);
2708 }
2709};
2710
2711/// A canonicalizer wrapper to replace ExtractSliceOps.
2712struct SliceCanonicalizer {
2713 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2714 ExtractSliceOp newOp) {
2715 Value replacement = newOp.getResult();
2716 if (replacement.getType() != op.getType())
2717 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2718 replacement);
2719 rewriter.replaceOp(op, replacement);
2720 }
2721};
2722
2723void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2724 MLIRContext *context) {
2725 results.add<
2726 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2727 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2728 ExtractSliceOpCastFolder>(context);
2729}
2730
2731//
2732static LogicalResult
2733foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2734 ShapedType shapedType) {
2735 OpBuilder b(op.getContext());
2736 for (OpFoldResult ofr : op.getMixedOffsets())
2737 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2738 return failure();
2739 // Rank-reducing noops only need to inspect the leading dimensions:
2740 // llvm::zip is appropriate.
2741 auto shape = shapedType.getShape();
2742 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2743 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2744 return failure();
2745 for (OpFoldResult ofr : op.getMixedStrides())
2746 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2747 return failure();
2748 return success();
2749}
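// For example (illustrative), an extract_slice that covers the whole source
// with zero offsets and unit strides is a no-op:
//
// ```mlir
// %0 = tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
//     : tensor<4x8xf32> to tensor<4x8xf32>
// ```
//
// folds to %t.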
2750
2751/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2752/// slice, we can return the InsertSliceOp's source directly.
2753// TODO: This only checks the immediate producer; extend to go up the
2754// insert/extract chain if the slices are disjoint.
2755static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2756 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2757
2758 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2759 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2760 insertOp.isSameAs(extractOp, isSame))
2761 return insertOp.getSource();
2762
2763 return {};
2764}
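// For example (illustrative):
//
// ```mlir
// %0 = tensor.insert_slice %src into %dest[0, 1] [2, 4] [1, 1]
//     : tensor<2x4xf32> into tensor<8x8xf32>
// %1 = tensor.extract_slice %0[0, 1] [2, 4] [1, 1]
//     : tensor<8x8xf32> to tensor<2x4xf32>
// ```
//
// %1 folds to %src because both ops describe the same slice.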
2765
2766OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2767 if (OpFoldResult reshapedSource = reshapeConstantSource(
2768 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2769 getResult().getType()))
2770 return reshapedSource;
2771 if (getSourceType() == getType() &&
2772 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2773 return this->getSource();
2774 if (Value slice = foldExtractAfterInsertSlice(*this))
2775 return slice;
2776
2777 return OpFoldResult();
2778}
2779
2779
2780Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2781 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2782 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2783 unsigned rank = rankedTensorType.getRank();
2784 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2785 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2786 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2787 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2788 offsets, sizes, strides);
2789}
2790
2791//===----------------------------------------------------------------------===//
2792// InsertSliceOp
2793//===----------------------------------------------------------------------===//
2794
2795void InsertSliceOp::getAsmResultNames(
2796 function_ref<void(Value, StringRef)> setNameFn) {
2797 setNameFn(getResult(), "inserted_slice");
2798}
2799
2800// Build an InsertSliceOp with mixed static and dynamic entries.
2801void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2802 Value dest, ArrayRef<OpFoldResult> offsets,
2803 ArrayRef<OpFoldResult> sizes,
2804 ArrayRef<OpFoldResult> strides,
2805 ArrayRef<NamedAttribute> attrs) {
2806 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2807 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2808 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2809 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2810 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2811 result.addAttributes(attrs);
2812 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2813 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2814 b.getDenseI64ArrayAttr(staticSizes),
2815 b.getDenseI64ArrayAttr(staticStrides));
2816}
2817
2818/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2819/// Range vector.
2820void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2821 Value dest, ArrayRef<Range> ranges,
2822 ArrayRef<NamedAttribute> attrs) {
2823 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2824 build(b, result, source, dest, offsets, sizes, strides, attrs);
2825}
2826
2827// Build an InsertSliceOp with dynamic entries.
2828void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2829 Value dest, ValueRange offsets, ValueRange sizes,
2830 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2831 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2832 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2833 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2834 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2835 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2836 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2837 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2838}
2839
2840/// Rank-reducing type verification for both InsertSliceOp and
2841/// ParallelInsertSliceOp.
2842static SliceVerificationResult verifyInsertSliceOp(
2843 RankedTensorType srcType, RankedTensorType dstType,
2844 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2845 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2846 // insert_slice is the inverse of extract_slice, use the same type
2847 // inference.
2848 RankedTensorType expected =
2849 ExtractSliceOp::inferResultType(dstType, staticSizes);
2850 if (expectedType)
2851 *expectedType = expected;
2852 return isRankReducedType(expected, srcType);
2853}
2854
2855/// Verifier for InsertSliceOp.
2856LogicalResult InsertSliceOp::verify() {
2857 // Verify result type against inferred type.
2858 RankedTensorType expectedType;
2859 SliceVerificationResult result =
2860 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2861 getStaticSizes(), getStaticStrides(), &expectedType);
2862 if (result != SliceVerificationResult::Success)
2863 return produceSliceErrorMsg(result, *this, expectedType);
2864
2865 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2866 // to the destination tensor.
2867 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2868 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2869 getStaticStrides(), /*generateErrorMessage=*/true);
2870 if (!boundsResult.isValid)
2871 return getOperation()->emitError(boundsResult.errorMessage);
2872
2873 return success();
2874}
2875
2876/// If we have two consecutive InsertSliceOps writing to the same slice, we
2877/// can mutate the second InsertSliceOp's destination to the first one's.
2878///
2879/// Example:
2880///
2881/// ```mlir
2882/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2883/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2884/// ```
2885///
2886/// folds into:
2887///
2888/// ```mlir
2889/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2890/// ```
2891///
2892/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2893static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2894 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2895
2896 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2897 if (!prevInsertOp ||
2898 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2899 !prevInsertOp.isSameAs(insertOp, isSame))
2900 return failure();
2901
2902 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2903 return success();
2904}
2905
2906/// Folds round-trip extract/insert slice op pairs.
2907/// Example:
2908/// ```mlir
2909/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2910/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2911/// ```
2912/// can be folded into %val.
2913static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2914 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2915
2916 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2917 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2918 !extractOp.isSameAs(insertOp, isSame))
2919 return nullptr;
2920
2921 return extractOp.getSource();
2922}
2923
2924OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2925 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2926 getSourceType() == getType() &&
2927 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2928 return this->getSource();
2929 if (succeeded(foldInsertAfterInsertSlice(*this)))
2930 return getResult();
2931 if (auto result = foldInsertAfterExtractSlice(*this))
2932 return result;
2933 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2934 return getDest();
2935 return OpFoldResult();
2936}
2937
2938LogicalResult InsertSliceOp::reifyResultShapes(
2939 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2940 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2941 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2942 return success();
2943}
2944
2945namespace {
2946/// Pattern to rewrite an insert_slice op with constant arguments.
2947///
2948/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2949template <typename InsertOpTy>
2950class InsertSliceOpConstantArgumentFolder final
2951 : public OpRewritePattern<InsertOpTy> {
2952public:
2953 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2954
2955 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2956 PatternRewriter &rewriter) const override {
2957 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2958 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2959 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2960
2961 // No constant operands were folded, just return.
2962 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2963 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2964 failed(foldDynamicStrideList(mixedStrides)))
2965 return failure();
2966
2967 // Pattern does not apply if the produced op would not verify.
2968 SliceBoundsVerificationResult sliceResult =
2969 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2970 mixedOffsets, mixedSizes, mixedStrides);
2971 if (!sliceResult.isValid)
2972 return failure();
2973
2974 // Create the new op in canonical form.
2975 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2976 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2977 mixedSizes);
2978 Value toInsert = insertSliceOp.getSource();
2979 if (sourceType != insertSliceOp.getSourceType()) {
2980 OpBuilder::InsertionGuard g(rewriter);
2981 // The only difference between InsertSliceOp and ParallelInsertSliceOp
2982 // is that the insertion point is just before the InParallelOp in
2983 // the parallel case.
2984 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
2985 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2986 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
2987 sourceType, toInsert);
2988 }
2989 rewriter.replaceOpWithNewOp<InsertOpTy>(
2990 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2991 mixedSizes, mixedStrides);
2992 return success();
2993 }
2994};
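// For example (illustrative; %src, %dest, %c0, and %c2 are hypothetical, with
// %c0 and %c2 defined by arith.constant):
//
// ```mlir
// %0 = tensor.insert_slice %src into %dest[%c0, %c0] [%c2, %c2] [1, 1]
//     : tensor<?x?xf32> into tensor<4x4xf32>
// ```
//
// canonicalizes to
//
// ```mlir
// %cast = tensor.cast %src : tensor<?x?xf32> to tensor<2x2xf32>
// %0 = tensor.insert_slice %cast into %dest[0, 0] [2, 2] [1, 1]
//     : tensor<2x2xf32> into tensor<4x4xf32>
// ```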
2995
2996/// Fold tensor_casts with insert_slice operations. If the source or
2997/// destination tensor is a tensor_cast that removes static type information,
2998/// the cast is folded into the insert_slice operation. E.g.:
2999///
3000/// ```mlir
3001/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3002/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3003/// ```
3004///
3005/// folds into:
3006///
3007/// ```mlir
3008/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3009/// ```
3010///
3011/// Note: When folding a cast on the destination tensor, the result of the
3012/// insert_slice operation is cast back to ensure that the type of the result
3013/// does not change.
3014///
3015/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3016template <typename InsertOpTy>
3017struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3018 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3019
3020 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3021 PatternRewriter &rewriter) const override {
3022 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3023 return matchPattern(operand, matchConstantIndex());
3024 }))
3025 return failure();
3026
3027 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3028 auto castOp = v.getDefiningOp<tensor::CastOp>();
3029 if (!castOp || !canFoldIntoConsumerOp(castOp))
3030 return std::nullopt;
3031 return castOp.getSource();
3032 };
3033 std::optional<Value> sourceCastSource =
3034 getSourceOfCastOp(insertSliceOp.getSource());
3035 std::optional<Value> destCastSource =
3036 getSourceOfCastOp(insertSliceOp.getDest());
3037 if (!sourceCastSource && !destCastSource)
3038 return failure();
3039
3040 auto src =
3041 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3042 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3043 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3044 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3045 if (!srcType || !dstType)
3046 return failure();
3047
3048 // The tensor.cast source could have additional static information not seen
3049 // in the insert slice op static sizes, so we ignore dynamic dims when
3050 // computing the rank reduction mask.
3051 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3052 auto rankReductionMask = computeRankReductionMask(
3053 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3054 if (!rankReductionMask.has_value())
3055 return failure();
3056 // Replace dimensions in the insert slice op with corresponding static dims
3057 // from the cast source type. If the insert slice sizes have static dims
3058 // that are not static in the tensor.cast source (i.e., when the cast op
3059 // casts a dynamic dim to static), the dim should not be replaced, and the
3060 // pattern will fail later in `verifyInsertSliceOp`.
3061 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3062 int64_t rankReducedIdx = 0;
3063 for (auto [idx, size] : enumerate(staticSizes)) {
3064 if (!rankReductionMask.value().contains(idx) &&
3065 !srcType.isDynamicDim(rankReducedIdx)) {
3066 mixedSizes[idx] = getAsIndexOpFoldResult(
3067 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3068 size = srcType.getDimSize(rankReducedIdx++);
3069 }
3070 }
3071
3072 // Pattern does not apply if the produced op would not verify.
3073 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3074 staticSizes, insertSliceOp.getStaticStrides()) !=
3075 SliceVerificationResult::Success)
3076 return failure();
3077 SliceBoundsVerificationResult sliceResult =
3078 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3079 mixedSizes, insertSliceOp.getMixedStrides());
3080 if (!sliceResult.isValid)
3081 return failure();
3082
3083 Operation *replacement =
3084 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3085 insertSliceOp.getMixedOffsets(), mixedSizes,
3086 insertSliceOp.getMixedStrides());
3087
3088 // In the parallel case there is no result and so nothing to cast.
3089 bool isParallelInsert =
3090 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3091 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3092 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3093 insertSliceOp.getDestType(),
3094 replacement->getResult(0));
3095 }
3096 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3097 return success();
3098 }
3099};
3100
3101/// If additional static type information can be deduced from an insert_slice's
3102/// size operands, insert an explicit cast of the op's source operand. This
3103/// enables other canonicalization patterns that match tensor_cast ops, such as
3104/// `ForOpTensorCastFolder` in SCF.
3105///
3106/// Example:
3107///
3108/// ```mlir
3109/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3110/// : tensor<?x?xf32> into ...
3111/// ```
3112///
3113/// folds into:
3114///
3115/// ```mlir
3116/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3117/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3118/// : tensor<64x64xf32> into ...
3119/// ```
3120///
3121/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3122template <typename InsertOpTy>
3123struct InsertSliceOpSourceCastInserter final
3124 : public OpRewritePattern<InsertOpTy> {
3125 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3126
3127 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3128 PatternRewriter &rewriter) const override {
3129 RankedTensorType srcType = insertSliceOp.getSourceType();
3130 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3131 return failure();
3132 SmallVector<int64_t> newSrcShape(srcType.getShape());
3133 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3134 if (std::optional<int64_t> constInt =
3135 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3136 // Bail on invalid IR.
3137 if (*constInt < 0)
3138 return failure();
3139 newSrcShape[i] = *constInt;
3140 }
3141 }
3142 if (!hasValidSizesOffsets(newSrcShape))
3143 return failure();
3144
3145 RankedTensorType newSrcType = RankedTensorType::get(
3146 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3147 if (srcType == newSrcType ||
3148 !preservesStaticInformation(srcType, newSrcType) ||
3149 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3150 return failure();
3151
3152 // newSrcType is:
3153 // 1) Different from srcType.
3154 // 2) "More static" than srcType.
3155 // 3) Cast-compatible with srcType.
3156 // Insert the cast.
3157 OpBuilder::InsertionGuard g(rewriter);
3158 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3159 // that the insertion point is just before the InParallelOp in the
3160 // parallel case.
3161 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3162 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3163 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3164 newSrcType, insertSliceOp.getSource());
3165 rewriter.replaceOpWithNewOp<InsertOpTy>(
3166 insertSliceOp, cast, insertSliceOp.getDest(),
3167 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3168 insertSliceOp.getMixedStrides());
3169 return success();
3170 }
3171};
3172} // namespace
3173
3174llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3175 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3176}
3177
3178void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3179 MLIRContext *context) {
3180 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3181 InsertSliceOpCastFolder<InsertSliceOp>,
3182 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3183}
3184
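// Inserts `tensor` into `dest` at all-zero offsets, with unit strides and
// sizes taken from `dest`. For example (an illustrative case, not from the
// original source), inserting a tensor<4xf32> source into a tensor<1x4xf32>
// destination writes the entire destination through a rank-reducing
// insert_slice.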
3185Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3186 Location loc,
3187 Value tensor,
3188 Value dest) {
3189 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3190 unsigned rank = rankedTensorType.getRank();
3191 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3192 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3193 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3194 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3195 sizes, strides);
3196}
3197
3198//===----------------------------------------------------------------------===//
3199// PadOp
3200//===----------------------------------------------------------------------===//
3201
3202void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3203 setNameFn(getResult(), "padded");
3204}
3205
3206LogicalResult PadOp::verify() {
3207 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3208 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3209 auto expectedType =
3210 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3211 if (!expectedType) {
3212 return emitError("failed to infer expectedType from sourceType ")
3213 << sourceType << ", specified resultType is " << resultType;
3214 }
3215 if (resultType.getRank() != expectedType.getRank()) {
3216 return emitError("specified type ")
3217 << resultType << " does not match the inferred type "
3218 << expectedType;
3219 }
3220 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3221 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3222 continue;
3223 if (expectedType.isDynamicDim(i))
3224 continue;
3225 return emitError("specified type ")
3226 << resultType << " does not match the inferred type "
3227 << expectedType;
3228 }
3229
3230 return success();
3231}
3232
3233LogicalResult PadOp::verifyRegions() {
3234 auto &region = getRegion();
3235 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3236 Block &block = region.front();
3237 if (block.getNumArguments() != rank)
3238 return emitError("expected the block to have ") << rank << " arguments";
3239
3240 // Note: the number and type of yield values are checked in the YieldOp.
3241 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3242 if (!en.value().isIndex())
3243 return emitOpError("expected block argument ")
3244 << (en.index() + 1) << " to be an index";
3245 }
3246
3247 // Ensure that the region yields an element of the right type.
3248 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3249 if (yieldOp.getValue().getType() !=
3250 llvm::cast<ShapedType>(getType()).getElementType())
3251 return emitOpError("expected yield type to match shape element type");
3252
3253 return success();
3254}
3255
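// The inferred result shape below adds the static low/high padding amounts to
// each source dimension; a dimension becomes dynamic if the source dimension
// or either padding amount is dynamic. For example (sizes are illustrative),
// a tensor<4x?xf32> source with staticLow = [1, 2] and staticHigh = [3, 0]
// infers the result type tensor<8x?xf32>.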
3256RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3257 ArrayRef<int64_t> staticLow,
3258 ArrayRef<int64_t> staticHigh,
3259 ArrayRef<int64_t> resultShape) {
3260 unsigned rank = sourceType.getRank();
3261 if (staticLow.size() != rank)
3262 return RankedTensorType();
3263 if (staticHigh.size() != rank)
3264 return RankedTensorType();
3265 if (!resultShape.empty() && resultShape.size() != rank)
3266 return RankedTensorType();
3267
3268 SmallVector<int64_t, 4> inferredShape;
3269 for (auto i : llvm::seq<unsigned>(0, rank)) {
3270 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3271 staticHigh[i] == ShapedType::kDynamic) {
3272 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3273 : resultShape[i]);
3274 } else {
3275 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3276 assert((resultShape.empty() || size == resultShape[i] ||
3277 resultShape[i] == ShapedType::kDynamic) &&
3278 "mismatch between inferred shape and result shape");
3279 inferredShape.push_back(size);
3280 }
3281 }
3282
3283 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3284}
3285
3286void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3287 Value source, ArrayRef<int64_t> staticLow,
3288 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3289 bool nofold, ArrayRef<NamedAttribute> attrs) {
3290 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3291 if (!resultType)
3292 resultType = inferResultType(sourceType, staticLow, staticHigh);
3293 result.addAttributes(attrs);
3294 build(b, result, resultType, source, low, high,
3295 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3296 nofold ? b.getUnitAttr() : UnitAttr());
3297}
3298
3299void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3300 Value source, ValueRange low, ValueRange high, bool nofold,
3301 ArrayRef<NamedAttribute> attrs) {
3302 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3303 unsigned rank = sourceType.getRank();
3304 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3305 build(b, result, resultType, source, staticVector, staticVector, low, high,
3306 nofold, attrs);
3307}
3308
3309void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3310 Value source, ArrayRef<OpFoldResult> low,
3311 ArrayRef<OpFoldResult> high, bool nofold,
3312 ArrayRef<NamedAttribute> attrs) {
3313 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3314 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3315 SmallVector<int64_t, 4> staticLow, staticHigh;
3316 // staticLow and staticHigh have full information of the padding config.
3317 // Each entry grows staticLow and staticHigh by one value. If the entry is
3318 // dynamic (i.e., not a constant), dynamicLow and dynamicHigh grow by one
3319 // value as well.
3320 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3321 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3322 if (!resultType) {
3323 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3324 }
3325 assert(llvm::isa<RankedTensorType>(resultType));
3326 result.addAttributes(attrs);
3327 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3328 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3329 nofold ? b.getUnitAttr() : UnitAttr());
3330}
3331
3332void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3333 Value source, ArrayRef<OpFoldResult> low,
3334 ArrayRef<OpFoldResult> high, Value constantPadValue,
3335 bool nofold, ArrayRef<NamedAttribute> attrs) {
3336 build(b, result, resultType, source, low, high, nofold, attrs);
3337
3338 // Add a region and a block to yield the pad value.
3339 Region *region = result.regions[0].get();
3340 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3341 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3342 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3343
3344 // `builder.createBlock` changes the insertion point within the block. Create
3345 // a guard to reset the insertion point of the builder after it is destroyed.
3346 OpBuilder::InsertionGuard guard(b);
3347 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3348 tensor::YieldOp::create(b, result.location, constantPadValue);
3349}
3350
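// The bit vector computed below marks a dimension as padded unless both its
// low and high padding amounts are the constant zero; e.g. with low = [0, 2]
// and high = [0, 0], only dimension 1 is set.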
3351llvm::SmallBitVector PadOp::getPaddedDims() {
3352 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3353 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3354 for (const auto &en : enumerate(paddingWidths))
3355 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3356 paddedDims.set(en.index());
3357 };
3358 extractPaddedDims(getMixedLowPad());
3359 extractPaddedDims(getMixedHighPad());
3360 return paddedDims;
3361}
3362
3363namespace {
3364// Folds tensor.pad ops whose low and high padding amounts are all static
3365// zeros and that do not have the `nofold` attribute set.
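// For example (a minimal sketch; names and types are illustrative):
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<4x4xf32>
//
// is rewritten to `tensor.cast %src : tensor<4x4xf32> to tensor<4x4xf32>`,
// which in turn folds away.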
3366struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3367 using OpRewritePattern<PadOp>::OpRewritePattern;
3368
3369 LogicalResult matchAndRewrite(PadOp padTensorOp,
3370 PatternRewriter &rewriter) const override {
3371 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3372 return failure();
3373 if (padTensorOp.getNofold())
3374 return failure();
3375 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3376 padTensorOp, padTensorOp.getResult().getType(),
3377 padTensorOp.getSource());
3378 return success();
3379 }
3380};
3381
3382// Fold CastOp into PadOp when adding static information.
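// For example (a minimal sketch; sizes and names are illustrative):
//
//   %c = tensor.cast %src : tensor<8x16xf32> to tensor<?x?xf32>
//   %p = tensor.pad %c low[0, 0] high[2, 2] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
//
// becomes a pad of %src with the inferred result type tensor<10x18xf32>,
// followed by a tensor.cast of the new result back to the original
// tensor<?x?xf32> type.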
3383struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3384 using OpRewritePattern<PadOp>::OpRewritePattern;
3385
3386 LogicalResult matchAndRewrite(PadOp padTensorOp,
3387 PatternRewriter &rewriter) const override {
3388 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3389 if (!tensor::canFoldIntoConsumerOp(castOp))
3390 return failure();
3391
3392 auto newResultType = PadOp::inferResultType(
3393 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3394 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3395 padTensorOp.getResultType().getShape());
3396
3397 if (newResultType == padTensorOp.getResultType()) {
3398 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3399 padTensorOp.getSourceMutable().assign(castOp.getSource());
3400 });
3401 } else {
3402 auto newOp = PadOp::create(
3403 rewriter, padTensorOp->getLoc(), newResultType,
3404 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3405 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3406 padTensorOp.getHigh(), padTensorOp.getNofold(),
3407 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3408 IRMapping mapper;
3409 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3410
3411 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3412 padTensorOp, padTensorOp.getResultType(), newOp);
3413 }
3414 return success();
3415 }
3416};
3417
3418// Fold a tensor.cast of a PadOp's result back into the PadOp if the cast adds
3419// static information.
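// For example (a minimal sketch; sizes and names are illustrative):
//
//   %p = tensor.pad %src low[0] high[2] { ... } : tensor<?xf32> to tensor<?xf32>
//   %c = tensor.cast %p : tensor<?xf32> to tensor<10xf32>
//
// is rewritten so that the pad directly produces tensor<10xf32> and the cast
// is removed. The cast must be the pad result's only user.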
3420struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3421 using OpRewritePattern<PadOp>::OpRewritePattern;
3422
3423 LogicalResult matchAndRewrite(PadOp padTensorOp,
3424 PatternRewriter &rewriter) const override {
3425 if (!padTensorOp.getResult().hasOneUse())
3426 return failure();
3427 auto tensorCastOp =
3428 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3429 if (!tensorCastOp)
3430 return failure();
3431 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3432 tensorCastOp.getDest().getType()))
3433 return failure();
3434
3435 auto replacementOp = PadOp::create(
3436 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3437 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3438 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3439 padTensorOp.getHigh(), padTensorOp.getNofold(),
3440 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3441 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3442
3443 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3444 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3445 return success();
3446 }
3447};
3448
3449/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3450/// different dimensions. The pattern applies if the following preconditions
3451/// hold:
3452/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3453/// 2) the tensor::ExtractSliceOps have only unit-strides,
3454/// 3) the tensor::PadOps perform only high-padding,
3455/// 4) the tensor::PadOps have the same constant padding value,
3456/// 5) the tensor::PadOps do not have common padding dimensions,
3457/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3458/// zero-offset for every dimension.
3459/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3460///    the padded source dimensions.
3462///
3463/// Example:
3464///
3465/// ```mlir
3466/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3467/// : tensor<64x64xf32> to tensor<?x64xf32>
3468/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3469/// } : tensor<?x64xf32> to tensor<8x64xf32>
3470/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3471/// : tensor<8x64xf32> to tensor<8x?xf32>
3472/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3473/// } : tensor<8x?xf32> to tensor<8x4xf32>
3474/// ```
3475///
3476/// folds into:
3477///
3478/// ```mlir
3479/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3480/// : tensor<64x64xf32> to tensor<?x?xf32>
3481/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3482/// } : tensor<?x?xf32> to tensor<8x4xf32>
3483/// ```
3484struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3485 using OpRewritePattern<PadOp>::OpRewritePattern;
3486
3487 LogicalResult matchAndRewrite(PadOp padOp,
3488 PatternRewriter &rewriter) const override {
3489 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3490 if (!innerSliceOp)
3491 return failure();
3492 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3493 if (!outerPadOp || outerPadOp.getNofold())
3494 return failure();
3495 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3496 if (!outerSliceOp)
3497 return failure();
3498
3499 // 1) Fail if the chain is rank-reducing.
3500 int64_t rank = padOp.getSourceType().getRank();
3501 if (outerSliceOp.getSourceType().getRank() != rank) {
3502 return rewriter.notifyMatchFailure(padOp,
3503 "cannot fold rank-reducing chain");
3504 }
3505
3506 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3507 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3508 return rewriter.notifyMatchFailure(
3509 padOp, "cannot fold non-unit stride ExtractSliceOps");
3510 }
3511
3512 // 3) Fail if the tensor::PadOps have non-zero low padding.
3513 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3514 return rewriter.notifyMatchFailure(padOp,
3515 "cannot fold PadOps with low padding");
3516 }
3517
3518 // 4) Fail if the tensor::PadOps padding values do not match.
3519 Attribute innerAttr, outerAttr;
3520 Value innerValue = padOp.getConstantPaddingValue();
3521 Value outerValue = outerPadOp.getConstantPaddingValue();
3522 if (!innerValue || !outerValue ||
3523 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3524 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3525 innerAttr != outerAttr) {
3526 return rewriter.notifyMatchFailure(
3527 padOp, "cannot fold PadOps with different padding values");
3528 }
3529
3530 // 5) Fail if a dimension is padded by both tensor::PadOps.
3531 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3532 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3533 if (innerDims.anyCommon(outerDims)) {
3534 return rewriter.notifyMatchFailure(
3535 padOp, "cannot fold PadOps with common padding dimensions");
3536 }
3537
3538 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3539 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3540 // for every dimension, and use the offset of the other pair. Fail if no
3541 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3542 // exists.
3543 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3544 for (auto en : enumerate(newOffsets)) {
3545 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3546 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3547 if (!innerDims.test(en.index()) &&
3548 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3549 en.value() = outerOffset;
3550 continue;
3551 }
3552 if (!outerDims.test(en.index()) &&
3553 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3554 en.value() = innerOffset;
3555 continue;
3556 }
3557 return rewriter.notifyMatchFailure(
3558 padOp, "cannot find zero-offset and zero-padding pair");
3559 }
3560
3561 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3562 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3563 // outer tensor::PadOp and fail if the size of the inner
3564 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3565 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3566 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3567 for (auto en : enumerate(newSizes)) {
3568 if (!outerDims.test(en.index()))
3569 continue;
3570 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3571 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3572 assert(ShapedType::isStatic(sourceSize) &&
3573 "expected padded dimension to have a static size");
3574 if (getConstantIntValue(sliceSize) != sourceSize) {
3575 return rewriter.notifyMatchFailure(
3576 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3577 "match the size of the outer padding");
3578 }
3579 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3580 }
3581
3582 // Combine the high paddings of the two tensor::PadOps.
3583 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3584 for (auto en : enumerate(newHighPad)) {
3585 if (innerDims.test(en.index()))
3586 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3587 if (outerDims.test(en.index()))
3588 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3589 }
3590
3591 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3592 // the two paddings in one step.
3593 auto newSliceOp = ExtractSliceOp::create(
3594 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3595 newSizes, innerSliceOp.getMixedStrides());
3596 auto newPadOp = PadOp::create(
3597 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3598 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3599 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3600 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3601 newPadOp.getRegion().begin());
3602 rewriter.replaceOp(padOp, newPadOp.getResult());
3603 return success();
3604 }
3605};
3606
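// Folds constant low/high padding operands of a tensor.pad into its static
// low/high attributes and, where possible, computes a more static result type,
// casting back to the original result type.
//
// For example (a minimal sketch; %c2 is assumed to be an arith.constant 2):
//
//   %p = tensor.pad %src low[%c2, 0] high[0, 1] { ... }
//       : tensor<2x3xf32> to tensor<?x4xf32>
//
// becomes
//
//   %p = tensor.pad %src low[2, 0] high[0, 1] { ... }
//       : tensor<2x3xf32> to tensor<4x4xf32>
//   %r = tensor.cast %p : tensor<4x4xf32> to tensor<?x4xf32>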
3607struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3608 using OpRewritePattern<PadOp>::OpRewritePattern;
3609
3610 LogicalResult matchAndRewrite(PadOp padTensorOp,
3611 PatternRewriter &rewriter) const override {
3612 Value input = padTensorOp.getSource();
3613 if (!llvm::isa<RankedTensorType>(input.getType()))
3614 return failure();
3615 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3616 auto inputRank = inputDims.size();
3617
3618 auto oldResultType =
3619 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3620 if (!oldResultType)
3621 return failure();
3622
3623 auto outputDims = oldResultType.getShape();
3624
3625 // Extract the static info from the high and low operands.
3626 SmallVector<int64_t> constOperandsLow;
3627 SmallVector<Value> newLows;
3628 for (auto operand : padTensorOp.getLow()) {
3629 APSInt intOp;
3630 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3631 constOperandsLow.push_back(ShapedType::kDynamic);
3632 newLows.push_back(operand);
3633 continue;
3634 }
3635 constOperandsLow.push_back(intOp.getExtValue());
3636 }
3637 SmallVector<int64_t> constOperandsHigh;
3638 SmallVector<Value> newHighs;
3639 for (auto operand : padTensorOp.getHigh()) {
3640 APSInt intOp;
3641 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3642 constOperandsHigh.push_back(ShapedType::kDynamic);
3643 newHighs.push_back(operand);
3644 continue;
3645 }
3646 constOperandsHigh.push_back(intOp.getExtValue());
3647 }
3648
3649 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3650 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3651
3652 // Verify the op is well-formed.
3653 if (inputDims.size() != outputDims.size() ||
3654 inputDims.size() != constLow.size() ||
3655 inputDims.size() != constHigh.size())
3656 return failure();
3657
3658 auto lowCount = 0;
3659 auto highCount = 0;
3660 for (size_t i = 0; i < inputRank; i++) {
3661 if (constLow[i] == ShapedType::kDynamic)
3662 constLow[i] = constOperandsLow[lowCount++];
3663 if (constHigh[i] == ShapedType::kDynamic)
3664 constHigh[i] = constOperandsHigh[highCount++];
3665 }
3666
3667 auto staticLow = ArrayRef<int64_t>(constLow);
3668 auto staticHigh = ArrayRef<int64_t>(constHigh);
3669
3670 // Calculate the output sizes with the static information.
3671 SmallVector<int64_t> newOutDims;
3672 for (size_t i = 0; i < inputRank; i++) {
3673 if (outputDims[i] == ShapedType::kDynamic) {
3674 newOutDims.push_back(
3675 (staticLow[i] == ShapedType::kDynamic ||
3676 staticHigh[i] == ShapedType::kDynamic ||
3677 inputDims[i] == ShapedType::kDynamic
3678 ? ShapedType::kDynamic
3679 : inputDims[i] + staticLow[i] + staticHigh[i]));
3680 } else {
3681 newOutDims.push_back(outputDims[i]);
3682 }
3683 }
3684
3685 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3686 llvm::all_of(newOutDims,
3687 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3688 return failure();
3689
3690 // Rewrite the op using the new static type.
3691 auto newResultType = RankedTensorType::get(
3692 newOutDims, padTensorOp.getType().getElementType());
3693 auto newOp = PadOp::create(
3694 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3695 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3696 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3697
3698 IRMapping mapper;
3699 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3700 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3701 newOp);
3702
3703 return success();
3704 }
3705};
3706
3707/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3708///
3709/// Example:
3710///
3711/// ```mlir
3712/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3713/// tensor.yield %val
3714/// } : tensor<1x2xf32> to tensor<1x5xf32>
3715/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3716/// tensor.yield %val
3717/// } : tensor<1x5xf32> to tensor<4x7xf32>
3718/// ```
3719///
3720/// folds into:
3721///
3722/// ```mlir
3723/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3724/// tensor.yield %val
3725/// } : tensor<1x2xf32> to tensor<4x7xf32>
3726/// ```
3727struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3728 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3729
3730 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3731 PatternRewriter &rewriter) const override {
3732 if (padOp.getNofold()) {
3733 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3734 }
3735
3736 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3737 if (!producerPad || producerPad.getNofold()) {
3738 return rewriter.notifyMatchFailure(
3739 padOp, "producer is not a foldable tensor.pad op");
3740 }
3741
3742 // Fail if the tensor::PadOps padding values do not match.
3743 Value consumerPadValue = padOp.getConstantPaddingValue();
3744 Value producerPadValue = producerPad.getConstantPaddingValue();
3745 if (!consumerPadValue || !producerPadValue ||
3746 consumerPadValue != producerPadValue) {
3747 return rewriter.notifyMatchFailure(
3748 padOp,
3749 "cannot fold PadOps with different or non-constant padding values");
3750 }
3751
3752 Location loc = padOp.getLoc();
3753 AffineExpr d0, d1;
3754 bindDims(rewriter.getContext(), d0, d1);
3755
3756 // Combine the low/high paddings of the two tensor::PadOps.
3757 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3758 ArrayRef<OpFoldResult> producerPaddings) {
3759 SmallVector<OpFoldResult> sumPaddings;
3760 for (auto [consumerIndex, producerIndex] :
3761 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3762 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3763 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3764 }
3765 return sumPaddings;
3766 };
3767
3768 SmallVector<OpFoldResult> newHighPad =
3769 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3770 SmallVector<OpFoldResult> newLowPad =
3771 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3772
3773 auto newPadOp = tensor::PadOp::create(
3774 rewriter, padOp.getLoc(), padOp.getResultType(),
3775 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3776 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3777 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3778 newPadOp.getRegion().begin());
3779 rewriter.replaceOp(padOp, newPadOp.getResult());
3780 return success();
3781 }
3782};
3783
3784} // namespace
3785
3786LogicalResult
3787PadOp::reifyResultShapes(OpBuilder &b,
3788 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3789 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3790 SmallVector<OpFoldResult> lp = getMixedLowPad();
3791 SmallVector<OpFoldResult> hp = getMixedHighPad();
3792 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3793 if (!getType().isDynamicDim(i)) {
3794 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3795 continue;
3796 }
3797 Location loc = getLoc();
3798 Value dim = b.createOrFold<tensor::DimOp>(
3799 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3800
3801 AffineExpr d0, d1, d2;
3802 bindDims(b.getContext(), d0, d1, d2);
3803 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3804 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3805 }
3806 return success();
3807}
3808
3809void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3810 MLIRContext *context) {
3811 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3812 FoldOrthogonalPaddings, FoldStaticPadding,
3813 FoldConsecutiveConstantPadding>(context);
3814}
3815
3816/// Return the padding value of the PadOp if it is constant. In this context,
3817/// "constant" means an actual constant or "defined outside of the block".
3818///
3819/// Values are considered constant in three cases:
3820/// - A ConstantLike value.
3821/// - A basic block argument from a different block.
3822/// - A value defined outside of the block.
3823///
3824/// If the padding value is not constant, an empty Value is returned.
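/// For example, a value produced by an arith.constant above the tensor.pad, or
/// a block argument of the enclosing function, qualifies; a value computed
/// inside the pad's own body block does not.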
3825Value PadOp::getConstantPaddingValue() {
3826 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3827 if (!yieldOp)
3828 return {};
3829 Value padValue = yieldOp.getValue();
3830 // Check if yield value is a constant.
3831 if (matchPattern(padValue, m_Constant()))
3832 return padValue;
3833 // Check if yield value is defined inside the PadOp block.
3834 if (padValue.getParentBlock() == &getRegion().front())
3835 return {};
3836 // Else: Yield value defined outside of the PadOp block.
3837 return padValue;
3838}
3839
3840OpFoldResult PadOp::fold(FoldAdaptor) {
3841 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3842 !getNofold())
3843 return getSource();
3844 return {};
3845}
3846
3847//===----------------------------------------------------------------------===//
3848// ParallelInsertSliceOp
3849//===----------------------------------------------------------------------===//
3850
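// A parallel_insert_slice is tied to the enclosing iterating op's result (e.g.
// an scf.forall result) at the same position this op occupies among the
// yielding ops of its InParallelOpInterface parent; the loop below finds that
// position.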
3851OpResult ParallelInsertSliceOp::getTiedOpResult() {
3852 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3853 for (const auto &it :
3854 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3855 Operation &nextOp = it.value();
3856 if (&nextOp == getOperation())
3857 return parallelCombiningParent.getParentResult(it.index());
3858 }
3859 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3860}
3861
3862// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3863void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3864 Value source, Value dest,
3865 ArrayRef<OpFoldResult> offsets,
3866 ArrayRef<OpFoldResult> sizes,
3867 ArrayRef<OpFoldResult> strides,
3868 ArrayRef<NamedAttribute> attrs) {
3869 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3870 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3871 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3872 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3873 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3874 result.addAttributes(attrs);
3875 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3876 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3877 b.getDenseI64ArrayAttr(staticSizes),
3878 b.getDenseI64ArrayAttr(staticStrides));
3879}
3880
3881/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3882/// packed into a Range vector.
3883void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3884 Value source, Value dest,
3885 ArrayRef<Range> ranges,
3886 ArrayRef<NamedAttribute> attrs) {
3887 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3888 build(b, result, source, dest, offsets, sizes, strides, attrs);
3889}
3890
3891// Build a ParallelInsertSliceOp with dynamic entries.
3892void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3893 Value source, Value dest, ValueRange offsets,
3894 ValueRange sizes, ValueRange strides,
3895 ArrayRef<NamedAttribute> attrs) {
3896 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3897 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3898 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3899 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3900 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3901 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3902 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3903}
3904
3905// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3906// to 0, strides set to 1 and inferred result type.
3907void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3908 Value dest, ArrayRef<OpFoldResult> sizes,
3909 ArrayRef<NamedAttribute> attrs) {
3910 Attribute zeroIdxAttr = b.getIndexAttr(0);
3911 Attribute oneIdxAttr = b.getIndexAttr(1);
3912 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3913 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3914 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3915}
3916
3917LogicalResult ParallelInsertSliceOp::verify() {
3918 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3919 return this->emitError("expected InParallelOpInterface parent, got:")
3920 << *(getOperation()->getParentOp());
3921
3922 // Verify result type against inferred type.
3923 RankedTensorType expectedType;
3924 SliceVerificationResult result =
3925 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3926 getStaticSizes(), getStaticStrides(), &expectedType);
3927 if (result != SliceVerificationResult::Success)
3928 return produceSliceErrorMsg(result, *this, expectedType);
3929
3930 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3931 // to the destination tensor.
3932 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3933 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3934 getStaticStrides(), /*generateErrorMessage=*/true);
3935 if (!boundsResult.isValid)
3936 return getOperation()->emitError(boundsResult.errorMessage);
3937
3938 return success();
3939}
3940
3941void ParallelInsertSliceOp::getCanonicalizationPatterns(
3942 RewritePatternSet &results, MLIRContext *context) {
3943 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3944 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3945 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3946}
3947
3948llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3949 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3950}
3951
3952// ParallelCombiningOpInterface implementation.
3953MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3954 return getDestMutable();
3955}
3956
3957Operation *ParallelInsertSliceOp::getIteratingParent() {
3958 // Return the parent InParallelOpInterface's parent.
3959 if (auto combiningOp =
3960 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3961 return combiningOp->getParentOp();
3962 return nullptr;
3963}
3964
3965//===----------------------------------------------------------------------===//
3966// ScatterOp
3967//===----------------------------------------------------------------------===//
3968
3969void ScatterOp::getAsmResultNames(
3970 function_ref<void(Value, StringRef)> setNameFn) {
3971 setNameFn(getResult(), "scatter");
3972}
3973
3974LogicalResult ScatterOp::verify() {
3975 int64_t destRank = getDestType().getRank();
3976 ArrayRef<int64_t> scatterDims = getScatterDims();
3977 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3978 getIndicesType().getShape(), destRank,
3979 "scatter", "dest")))
3980 return failure();
3981
3982 if (!getUnique())
3983 return emitOpError("requires 'unique' attribute to be set");
3984 // TODO: we could also check statically that there are fewer leading index
3985 // tensor dims than the dest dims. If this is not the case, the unique
3986 // attribute cannot be true.
3987
3988 // Use the GatherOp::inferResultType on the `dest` type and verify the
3989 // expected type matches the source type.
3990 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3991 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3992 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3993 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3994 if (getSourceType() != expectedSourceType &&
3995 getSourceType() != expectedRankReducedSourceType) {
3996 return emitOpError("source type "
3997 "mismatch: "
3998 "expected ")
3999 << expectedSourceType << " or its rank-reduced variant "
4000 << expectedRankReducedSourceType << " (got: " << getSourceType()
4001 << ")";
4002 }
4003
4004 return success();
4005}
4006
4007//===----------------------------------------------------------------------===//
4008// SplatOp
4009//===----------------------------------------------------------------------===//
4010
4011void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4012 Type aggregateType, ValueRange dynamicSizes) {
4013 build(builder, result, aggregateType, element, dynamicSizes);
4014}
4015
4016void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4017 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4018 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4019 build(builder, result, aggregateType, element, dynamicSizes);
4020}
4021
4022void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4023 ArrayRef<OpFoldResult> sizes) {
4024 SmallVector<int64_t> staticShape;
4025 SmallVector<Value> dynamicSizes;
4026 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4027 build(builder, result, element, staticShape, dynamicSizes);
4028}
4029
4030void SplatOp::getAsmResultNames(
4031 function_ref<void(Value, StringRef)> setNameFn) {
4032 setNameFn(getResult(), "splat");
4033}
4034
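// For example (an illustrative case), a splat producing tensor<2x?x?xf32> must
// be given exactly two dynamic size operands.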
4035LogicalResult SplatOp::verify() {
4036 if (getType().getNumDynamicDims() != getDynamicSizes().size())
4037 return emitOpError("incorrect number of dynamic sizes, has ")
4038 << getDynamicSizes().size() << ", expected "
4039 << getType().getNumDynamicDims();
4040 return success();
4041}
4042
4043LogicalResult
4044SplatOp::reifyResultShapes(OpBuilder &builder,
4045 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4046 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4047 unsigned ctr = 0;
4048 for (int64_t i = 0; i < getType().getRank(); ++i) {
4049 if (getType().isDynamicDim(i)) {
4050 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4051 } else {
4052 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4053 }
4054 }
4055 return success();
4056}
4057
4058OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4059 auto constOperand = adaptor.getInput();
4060 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4061 return {};
4062
4063 // Do not fold if the splat is not statically shaped
4064 if (!getType().hasStaticShape())
4065 return {};
4066
4067 // SplatElementsAttr::get treats a single value for the second arg as being
4068 // a splat.
4069 return SplatElementsAttr::get(getType(), {constOperand});
4070}
4071
4072//===----------------------------------------------------------------------===//
4073// Common Canonicalizers and Folders.
4074//===----------------------------------------------------------------------===//
4075static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4076 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4077 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4078 // might need special handling of attached regions.
4079 if (isa<InsertSliceOp>(op.getOperation()) ||
4080 isa<LoopLikeOpInterface>(op.getOperation()))
4081 return false;
4082
4083 return hasFoldableTensorCastOperand(op);
4084}
4085
4086/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4087/// the `tensor.cast` has source that is more static than the consuming op.
4088///
4089/// Example:
4090/// ```mlir
4091/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4092/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4093/// ```
4094///
4095/// folds into:
4096///
4097/// ```mlir
4098/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4099/// ```
4100/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4101/// can add the pattern to their canonicalizers.
4102struct FoldTensorCastProducerOp
4103 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4104 using OpInterfaceRewritePattern<
4105 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4106
4107 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4108 PatternRewriter &rewriter) const override {
4109
4110 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4111 // for that instead.
4112 if (!foldTensorCastPrecondition(op) ||
4113 isa<linalg::RelayoutOpInterface>(*op))
4114 return failure();
4115
4116 SmallVector<Type> newResultTypes(op->getResultTypes());
4117 SmallVector<Value> newOperands =
4118 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4119
4120 // Clone op
4121 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4122
4123 SmallVector<Value, 4> replacements;
4124 replacements.reserve(newOp->getNumResults());
4125 for (auto [oldResult, newResult] :
4126 llvm::zip(op->getResults(), newOp->getResults())) {
4127 if (newResult.getType() != oldResult.getType()) {
4128 replacements.push_back(tensor::CastOp::create(
4129 rewriter, op->getLoc(), oldResult.getType(), newResult));
4130 } else {
4131 replacements.push_back(newResult);
4132 }
4133 }
4134 rewriter.replaceOp(op, replacements);
4135
4136 return success();
4137 }
4138};
4139
4140//===----------------------------------------------------------------------===//
4141// TensorDialect
4142//===----------------------------------------------------------------------===//
4143
4144void TensorDialect::getCanonicalizationPatterns(
4145 RewritePatternSet &results) const {
4146 results.add<FoldTensorCastProducerOp>(getContext());
4147}
4148
4149//===----------------------------------------------------------------------===//
4150// TableGen'd op method definitions
4151//===----------------------------------------------------------------------===//
4152
4153#define GET_OP_CLASSES
4154#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
unsigned getNumArguments()
Definition Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:378
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
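These two helpers are typically paired inside a fold hook so that producing tensor.cast ops whose sources carry more static information get absorbed. A minimal sketch, assuming op is the operation being folded and the hook returns LogicalResult:

// Sketch: absorb foldable tensor.cast producers into this op's operands.
if (tensor::hasFoldableTensorCastOperand(op) &&
    succeeded(tensor::foldTensorCast(op)))
  return success();
return failure();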
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the size of dimension dim of the given tensor value, as an OpFoldResult.
Definition TensorOps.cpp:57
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:75
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source tensor type.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the sizes of all dimensions of the given tensor value, as OpFoldResults.
Definition TensorOps.cpp:66
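A common use is materializing an init tensor with the same shape as an existing value. A sketch; createEmptyLike is a hypothetical helper name:

// Build a tensor.empty whose mixed (static and dynamic) sizes match `source`.
static Value createEmptyLike(OpBuilder &b, Location loc, Value source) {
  auto type = cast<RankedTensorType>(source.getType());
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, source);
  return tensor::EmptyOp::create(b, loc, sizes, type.getElementType());
}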
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
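A sketch of how a transform gathers destination operands before rewriting op into destination-passing style; b, loc, and op are assumed to be in scope:

// One destination per tensor result: the tied init operand for destination
// style ops, otherwise a tensor.empty of the reified result shape.
SmallVector<Value> destinations;
if (failed(tensor::getOrCreateDestinations(b, loc, op, destinations)))
  return failure();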
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with each element for which ShapedType::isDynamic is true replaced by the corresponding value in dynamicValues.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
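For example, detecting that an operand is produced by an integer constant; indexValue is an assumed Value in scope:

// Bind the constant integer behind `indexValue`, if there is one.
APInt constIndex;
if (matchPattern(indexValue, m_ConstantInt(&constIndex))) {
  // The operand is a compile-time constant; use constIndex.getSExtValue().
}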
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
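A common pattern is to branch on whether a mixed offset or size is statically known; sizeOfr is an assumed OpFoldResult in scope:

// Prefer a static code path when the size is a known constant.
if (std::optional<int64_t> cst = getConstantIntValue(sizeOfr)) {
  // *cst holds the static size.
} else {
  // The size is dynamic; fall back to a Value-based computation.
}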
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
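As a concrete example, with row-major strides {12, 4, 1} (a 2x3x4 shape) the linear index 17 maps back to the multi-index {1, 1, 1}:

// 17 = 1*12 + 1*4 + 1*1 for a 2x3x4 row-major layout.
SmallVector<int64_t> indices =
    delinearize(/*linearIndex=*/17, /*strides=*/{12, 4, 1});
// indices == {1, 1, 1}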
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
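Together with getMixedValues above, this gives a round trip between the mixed form and the split static/dynamic representation stored by slice-like ops. A sketch, with mixedSizes and b assumed in scope:

// Split mixed sizes into static values (ShapedType::kDynamic placeholders)
// plus the dynamic SSA values, then rebuild the mixed form.
SmallVector<int64_t> staticSizes;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
SmallVector<OpFoldResult> roundTripped =
    getMixedValues(staticSizes, dynamicSizes, b.getContext());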
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
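This is the inverse direction of getValueOrCreateConstantIndexOp listed above. A sketch converting an OpFoldResult to an SSA value and back, with ofr, b, and loc assumed in scope:

// Lower to a Value (emitting an arith.constant for attributes), then recover
// the attribute form where the value is a known constant.
Value asValue = getValueOrCreateConstantIndexOp(b, loc, ofr);
OpFoldResult asOfr = getAsOpFoldResult(asValue);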
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:90
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
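For example, reducing a 1x4x1x8 shape to 4x8 drops the unit dimensions at positions 0 and 2:

// Which positions of the original shape are dropped unit dims.
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask(/*originalShape=*/{1, 4, 1, 8},
                             /*reducedShape=*/{4, 8});
// On success, *mask contains {0, 2}.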
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType by dropping some dimensions with static size 1.
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
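For instance, collapsing a 2x3x4 tensor into 6x4 groups the first two dimensions. A sketch, with b an assumed Builder in scope:

// Reassociation {{0, 1}, {2}} in the ArrayAttr form used by
// tensor.collapse_shape and tensor.expand_shape.
SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};
ArrayAttr reassociationAttr =
    getReassociationIndicesAttribute(b, reassociation);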
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
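A sketch of combining two possibly-dynamic extents, assuming the usual ShapedType::kDynamic encoding for unknown sizes:

// A dynamic extent desaturated against a static one yields the static value;
// two distinct static values would yield failure instead.
auto lhs = SaturatedInteger::wrap(ShapedType::kDynamic);
auto rhs = SaturatedInteger::wrap(8);
FailureOr<SaturatedInteger> combined = lhs.desaturate(rhs);
// succeeded(combined) holds and combined->asInteger() == 8.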
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
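These two fields are what verifyInBoundsSlice (listed earlier) fills in. A sketch of a verifier-style check against a 16x32 shape, with op an assumed Operation* in scope:

// Offsets {0, 8}, sizes {16, 24}, strides {1, 1} stay within 16x32, so
// result.isValid is true; otherwise errorMessage describes the violation.
SliceBoundsVerificationResult result = verifyInBoundsSlice(
    /*shape=*/{16, 32}, /*staticOffsets=*/{0, 8}, /*staticSizes=*/{16, 24},
    /*staticStrides=*/{1, 1}, /*generateErrorMessage=*/true);
if (!result.isValid)
  return op->emitOpError(result.errorMessage);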