TensorOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
18#include "mlir/IR/Builders.h"
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/Matchers.h"
32#include "mlir/Support/LLVM.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/SmallBitVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/MathExtras.h"
39#include <optional>
40
41using namespace mlir;
42using namespace mlir::tensor;
43
44/// Materialize a single constant operation from a given attribute value with
45/// the desired resultant type.
46Operation *TensorDialect::materializeConstant(OpBuilder &builder,
47 Attribute value, Type type,
48 Location loc) {
49 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
50 return op;
51 if (complex::ConstantOp::isBuildableWith(value, type))
52 return complex::ConstantOp::create(builder, loc, type,
53 llvm::cast<ArrayAttr>(value));
54 return nullptr;
55}
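// Illustrative materialization (sketch, not from the upstream file): folding a
// tensor op to the splat attribute dense<0.0> : tensor<4xf32> would typically
// re-materialize as
//   %cst = arith.constant dense<0.000000e+00> : tensor<4xf32>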
56
57OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
58 int64_t dim) {
59 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
60 if (tensorType.isDynamicDim(dim))
61 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
62
63 return builder.getIndexAttr(tensorType.getDimSize(dim));
64}
65
66SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
67 Location loc, Value value) {
68 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
69 SmallVector<OpFoldResult> result;
70 for (int64_t i = 0; i < tensorType.getRank(); ++i)
71 result.push_back(getMixedSize(builder, loc, value, i));
72 return result;
73}
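// Illustrative usage (sketch; %t and %dim are hypothetical names): for a value
// %t of type tensor<4x?xf32>, getMixedSizes(b, loc, %t) would yield
// {IndexAttr(4), %dim}, where %dim is created as
//   %dim = tensor.dim %t, %c1 : tensor<4x?xf32>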
74
75FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
76 OpResult opResult) {
77 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
78 assert(tensorType && "expected tensor type");
79
80 // If the op has a destination, it implements DestinationStyleOpInterface and
81 // we can query the destination operand from that interface.
82 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
83 if (destOp)
84 return destOp.getTiedOpOperand(opResult)->get();
85
86 // Otherwise, create a new destination tensor with the same shape.
87 OpBuilder::InsertionGuard g(b);
88 b.setInsertionPoint(opResult.getDefiningOp());
89
90 // Compute sizes.
91 SmallVector<OpFoldResult> mixedSizes;
92 if (!tensorType.hasStaticShape()) {
93 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
94 ReifiedRankedShapedTypeDims reifiedShapes;
95 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
96 return failure();
97 mixedSizes = reifiedShapes[opResult.getResultNumber()];
98 } else {
99 // Static shape: Take static sizes directly.
100 for (int64_t sz : tensorType.getShape())
101 mixedSizes.push_back(b.getIndexAttr(sz));
102 }
103
104 // Create empty tensor.
105 Value emptyTensor =
106 tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
107 return emptyTensor;
108}
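// Illustrative result (sketch; names are hypothetical): for a non-DPS op
// result of type tensor<?x16xf32>, this builds
//   %sz = ... (dynamic extent of dim 0, obtained via reifyResultShapes)
//   %dest = tensor.empty(%sz) : tensor<?x16xf32>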
109
110LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
111 Operation *op,
112 SmallVector<Value> &result) {
113 for (OpResult opResult : op->getResults()) {
114 if (llvm::isa<TensorType>(opResult.getType())) {
115 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
116 if (failed(destination))
117 return failure();
118 result.push_back(*destination);
119 }
120 }
121 return success();
122}
123
124bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
125 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
126 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
127 return rtp1.getShape() == rtp2.getShape() &&
128 rtp1.getElementType() == rtp2.getElementType();
129 return false;
130 }
131 return tp1 == tp2; // default implementation
132}
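// For example, tensor<4xf32> and tensor<4xf32, #enc> compare equal here (the
// encoding is ignored), while tensor<4xf32> and tensor<?xf32> do not, since
// their shapes differ.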
133
134/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135/// rank-extending tensor.insert_slice op.
136static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
137 ArrayRef<OpFoldResult> mixedSizes) {
138 llvm::SmallBitVector droppedDims(mixedSizes.size());
139 int64_t shapePos = reducedShape.size() - 1;
140
141 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
142 size_t idx = mixedSizes.size() - size.index() - 1;
143 // Rank-reduced dims must have a static unit dimension.
144 bool isStaticUnitSize =
145 isa<Attribute>(size.value()) &&
146 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
147
148 if (shapePos < 0) {
149 // There are no more dims in the reduced shape. All remaining sizes must
150 // be rank-reduced dims.
151 assert(isStaticUnitSize && "expected unit dim");
152 droppedDims.set(idx);
153 continue;
154 }
155
156 // Dim is preserved if the size is not a static 1.
157 if (!isStaticUnitSize) {
158 --shapePos;
159 continue;
160 }
161
162 // Dim is preserved if the reduced shape dim is also 1.
163 if (reducedShape[shapePos] == 1) {
164 --shapePos;
165 continue;
166 }
167
168 // Otherwise: Dim is dropped.
169 droppedDims.set(idx);
170 }
171
172 assert(shapePos < 0 && "dimension mismatch");
173 return droppedDims;
174}
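// For example, with reducedShape = [6, 8] and mixedSizes = [6, 1, 8], only
// bit 1 is set: the static unit size at index 1 has no counterpart in the
// reduced shape and is therefore a dropped (rank-reduced) dimension.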
175
176/// Given a ranked tensor type and a range of values that defines its dynamic
177/// dimension sizes, turn all dynamic sizes that have a constant value into
178/// static dimension sizes.
179static RankedTensorType
180foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
181 SmallVector<Value> &foldedDynamicSizes) {
182 SmallVector<int64_t> staticShape(type.getShape());
183 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
184 "incorrect number of dynamic sizes");
185
186 // Compute new static and dynamic sizes.
187 unsigned ctr = 0;
188 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189 if (type.isDynamicDim(i)) {
190 Value dynamicSize = dynamicSizes[ctr++];
191 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192 if (cst.has_value()) {
193 // Dynamic size must be non-negative.
194 if (cst.value() < 0) {
195 foldedDynamicSizes.push_back(dynamicSize);
196 continue;
197 }
198 staticShape[i] = *cst;
199 } else {
200 foldedDynamicSizes.push_back(dynamicSize);
201 }
202 }
203 }
204
205 return RankedTensorType::get(staticShape, type.getElementType(),
206 type.getEncoding());
207}
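// For example, folding tensor<?x?xf32> with dynamicSizes = [%c5, %n] (where
// %c5 is a constant 5 and %n is unknown; both names are hypothetical) yields
// tensor<5x?xf32> with foldedDynamicSizes = [%n].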
208
209//===----------------------------------------------------------------------===//
210// BitcastOp
211//===----------------------------------------------------------------------===//
212
213bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214 if (inputs.size() != 1 || outputs.size() != 1)
215 return false;
216 Type a = inputs.front(), b = outputs.front();
217 auto aT = dyn_cast<TensorType>(a);
218 auto bT = dyn_cast<TensorType>(b);
219 if (!aT || !bT)
220 return false;
221
222 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223 return false;
224
225 return succeeded(verifyCompatibleShape(aT, bT));
226}
227
228namespace {
229
230/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231/// operation.
232struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233 using OpRewritePattern<BitcastOp>::OpRewritePattern;
234
235 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236 PatternRewriter &rewriter) const final {
237 auto tensorBitcastOperand =
238 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239 if (!tensorBitcastOperand)
240 return failure();
241
242 auto resultType = cast<TensorType>(tensorBitcast.getType());
243 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244 tensorBitcastOperand.getOperand());
245 return success();
246 }
247};
248
249} // namespace
250
251void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252 MLIRContext *context) {
253 results.add<ChainedTensorBitcast>(context);
254}
255
256//===----------------------------------------------------------------------===//
257// CastOp
258//===----------------------------------------------------------------------===//
259
260void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261 setNameFn(getResult(), "cast");
262}
263
264/// Returns true if `target` is a ranked tensor type that preserves static
265/// information available in the `source` ranked tensor type.
266bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
267 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
269
270 // Requires RankedTensorType.
271 if (!sourceType || !targetType)
272 return false;
273
274 // Requires same elemental type.
275 if (sourceType.getElementType() != targetType.getElementType())
276 return false;
277
278 // Requires same rank.
279 if (sourceType.getRank() != targetType.getRank())
280 return false;
281
282 // Requires same encoding.
283 if (sourceType.getEncoding() != targetType.getEncoding())
284 return false;
285
286 // If cast is towards more static sizes along any dimension, don't fold.
287 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288 if (ShapedType::isStatic(std::get<0>(t)) &&
289 ShapedType::isDynamic(std::get<1>(t)))
290 return false;
291 }
292
293 return true;
294}
295
296/// Determines whether tensor::CastOp casts to a more dynamic version of the
297/// source tensor. This is useful to fold a tensor.cast into a consuming op and
298/// implement canonicalization patterns for ops in different dialects that may
299/// consume the results of tensor.cast operations. Such foldable tensor.cast
300/// operations are typically inserted as `slice` ops and are canonicalized
301/// to preserve the type compatibility of their uses.
302///
303/// Returns true when all conditions are met:
304/// 1. source and result are ranked tensors with same element type and rank.
305/// 2. the tensor type has more static information than the result
306///
307/// Example:
308/// ```mlir
309/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310/// %2 = consumer %1 ... : tensor<?x?xf32> ...
311/// ```
312///
313/// folds into:
314///
315/// ```mlir
316/// %2 = consumer %0 ... : tensor<8x16xf32> ...
317/// ```
318bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
319 if (!castOp)
320 return false;
321
322 // Can fold if the source of cast has at least as much static information as
323 // its results.
324 return preservesStaticInformation(castOp.getType(),
325 castOp.getSource().getType());
326}
327
328/// Determines whether the tensor::CastOp casts to a more static version of the
329/// source tensor. This is useful to fold into a producing op and implement
330/// canonicalization patterns with the `tensor.cast` op as the root, but
331/// producer being from different dialects. Returns true when all conditions are
332/// met:
333/// 1. source and result are ranked tensors with same element type and rank.
334/// 2. the result type has more static information than the source.
335///
336/// Example:
337/// ```mlir
338/// %1 = producer ... : tensor<?x?xf32>
339/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
340/// ```
341///
342/// can be canonicalized to:
343///
344/// ```mlir
345/// %2 = producer ... : tensor<8x16xf32>
346/// ```
347/// Not all ops might be canonicalizable this way, but for those that can be,
348/// this method provides a check that it is worth doing the canonicalization.
349bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
350 if (!castOp)
351 return false;
352 return preservesStaticInformation(castOp.getSource().getType(),
353 castOp.getType());
354}
355
356bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
357 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
358 if (llvm::isa<BlockArgument>(opOperand.get()))
359 return false;
360 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
361 return castOp && canFoldIntoConsumerOp(castOp);
362 });
363}
364
365SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
366 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
367 SmallVector<Value> newOperands;
368 newOperands.reserve(op->getNumOperands());
369
370 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
371
372 // Assumes that the result has dpsInits followed by nonDpsInits.
373 int64_t dpsInitIdx = 0;
374 for (OpOperand &opOperand : op->getOpOperands()) {
375 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
376 bool fold = canFoldIntoConsumerOp(tensorCastOp);
377 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
378 if (op.isDpsInit(&opOperand) &&
379 !llvm::isa<MemRefType>(newOperands.back().getType()))
380 newResTy[dpsInitIdx++] = newOperands.back().getType();
381 }
382 return newOperands;
383}
384
385/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
386/// that can be folded.
387LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
388 bool folded = false;
389 for (OpOperand &operand : op->getOpOperands()) {
390 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
391 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
392 operand.set(castOp.getOperand());
393 folded = true;
394 }
395 }
396 return success(folded);
397}
398
399bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
400 if (inputs.size() != 1 || outputs.size() != 1)
401 return false;
402 Type a = inputs.front(), b = outputs.front();
403 auto aT = llvm::dyn_cast<TensorType>(a);
404 auto bT = llvm::dyn_cast<TensorType>(b);
405 if (!aT || !bT)
406 return false;
407
408 if (aT.getElementType() != bT.getElementType())
409 return false;
410
411 return succeeded(verifyCompatibleShape(aT, bT));
412}
413
414/// Compute a TensorType that has the joined shape knowledge of the two
415/// given TensorTypes. The element types need to match.
416static TensorType joinShapes(TensorType one, TensorType two) {
417 assert(one.getElementType() == two.getElementType());
418
419 if (!one.hasRank())
420 return two;
421 if (!two.hasRank())
422 return one;
423
424 int64_t rank = one.getRank();
425 if (rank != two.getRank())
426 return {};
427
428 SmallVector<int64_t, 4> join;
429 join.reserve(rank);
430 for (int64_t i = 0; i < rank; ++i) {
431 if (one.isDynamicDim(i)) {
432 join.push_back(two.getDimSize(i));
433 continue;
434 }
435 if (two.isDynamicDim(i)) {
436 join.push_back(one.getDimSize(i));
437 continue;
438 }
439 if (one.getDimSize(i) != two.getDimSize(i))
440 return {};
441 join.push_back(one.getDimSize(i));
442 }
443 return RankedTensorType::get(join, one.getElementType());
444}
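// For example, joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) yields
// tensor<4x8xf32>, while joinShapes(tensor<4xf32>, tensor<8xf32>) yields the
// null type because the static sizes conflict.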
445
446namespace {
447
448/// Replaces chains of two tensor.cast operations by a single tensor.cast
449/// operation if doing so does not remove runtime constraints.
450struct ChainedTensorCast : public OpRewritePattern<CastOp> {
451 using OpRewritePattern<CastOp>::OpRewritePattern;
452
453 LogicalResult matchAndRewrite(CastOp tensorCast,
454 PatternRewriter &rewriter) const final {
455 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
456
457 if (!tensorCastOperand)
458 return failure();
459
460 auto sourceType =
461 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
462 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
463 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
464
465 // We can remove the intermediate cast if joining all three produces the
466 // same result as just joining the source and result shapes.
467 auto firstJoin =
468 joinShapes(joinShapes(sourceType, intermediateType), resultType);
469
470 // The join might not exist if the cast sequence would fail at runtime.
471 if (!firstJoin)
472 return failure();
473
474 // The newJoin always exists if the above join exists; it might just
475 // contain less information. If so, we cannot drop the intermediate cast,
476 // as doing so would remove runtime checks.
477 auto newJoin = joinShapes(sourceType, resultType);
478 if (firstJoin != newJoin)
479 return failure();
480
481 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
482 tensorCastOperand.getOperand());
483 return success();
484 }
485};
486
487/// Fold tensor.cast into a tensor.extract_slice producer.
488/// Example:
489/// ```
490/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
491/// tensor<128x512xf32> to tensor<?x512xf32>
492/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
493/// ```
494/// ->
495/// ```
496/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
497/// tensor<128x512xf32> to tensor<16x512xf32>
498/// ```
499struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
500 using OpRewritePattern<CastOp>::OpRewritePattern;
501
502 LogicalResult matchAndRewrite(CastOp tensorCast,
503 PatternRewriter &rewriter) const final {
504 auto extractOperand =
505 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
506
507 // Cannot fold cast to unranked tensor.
508 auto rankedResultType =
509 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
510 if (!rankedResultType)
511 return failure();
512
513 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
514 rankedResultType.getShape() ==
515 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
516 .getShape())
517 return failure();
518
519 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
520 auto dimMask = computeRankReductionMask(
521 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
522 size_t dimIndex = 0;
523 for (size_t i = 0, e = sizes.size(); i < e; i++) {
524 if (dimMask && dimMask->count(i))
525 continue;
526 int64_t dim = rankedResultType.getShape()[dimIndex++];
527 if (ShapedType::isDynamic(dim))
528 continue;
529 sizes[i] = rewriter.getIndexAttr(dim);
530 }
531
532 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
533 tensorCast, rankedResultType, extractOperand.getSource(),
534 extractOperand.getMixedOffsets(), sizes,
535 extractOperand.getMixedStrides());
536 return success();
537 }
538};
539
540} // namespace
541
542void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
543 MLIRContext *context) {
544 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
545}
546
547//===----------------------------------------------------------------------===//
548// ConcatOp
549//===----------------------------------------------------------------------===//
550
551RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
552 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
553 auto tensorTypes =
554 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
555 int64_t concatRank = tensorTypes[0].getRank();
556
557 // The concatenation dim must be in the range [0, rank).
558 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
559
560 SmallVector<int64_t> sizes(concatRank);
561 for (int64_t i = 0, e = concatRank; i < e; ++i) {
562 if (i == dim)
563 continue;
564 SaturatedInteger size;
565 for (auto tensorType : tensorTypes)
566 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
567 sizes[i] = size.asInteger();
568 }
569 auto concatSize = SaturatedInteger::wrap(0);
570 for (auto tensorType : tensorTypes)
571 concatSize =
572 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
573 sizes[dim] = concatSize.asInteger();
574 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
575}
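// For example, inferResultType(0, {tensor<3x?xf32>, tensor<?x4xf32>}) yields
// tensor<?x4xf32>: the concatenated dim 0 is 3 + ? = dynamic, and the static
// size 4 is joined into the non-concatenated dim 1.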
576
577void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
578 ValueRange inputs) {
579 FailureOr<RankedTensorType> resultType =
580 inferResultType(dim, inputs.getTypes());
581 assert(succeeded(resultType) && "failed to infer concatenation result type");
582 build(builder, result, *resultType, dim, inputs);
583}
584
585LogicalResult ConcatOp::verify() {
586 if (getInputs().size() < 1)
587 return emitOpError("requires at least one input");
588
589 SmallVector<RankedTensorType> inputTypes;
590 for (auto input : getInputs())
591 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
592
593 RankedTensorType resultType = getResultType();
594 int64_t resultRank = getRank();
595 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
596 return type.getRank() != resultRank;
597 }))
598 return emitOpError("rank of concatenated inputs must match result rank");
599
600 Type resultElementType = resultType.getElementType();
601 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
602 return type.getElementType() != resultElementType;
603 }))
604 return emitOpError("inputs and result element type must match");
605
606 int64_t dim = getDim();
607 if (dim >= resultRank)
608 return emitOpError("concatenation dim must be less than the tensor rank");
609
610 SmallVector<int64_t> sizes(resultRank);
611 for (int64_t i = 0, e = resultRank; i < e; ++i) {
612 if (i == dim)
613 continue;
614 SaturatedInteger size;
615 for (auto tensorType : inputTypes) {
616 FailureOr<SaturatedInteger> maybeSize =
617 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
618 if (failed(maybeSize))
619 return emitOpError("static concatenation size mismatch along ")
620 << "non-concatenated dimension " << i;
621 size = *maybeSize;
622 }
623 sizes[i] = size.asInteger();
624 }
625 auto concatSize = SaturatedInteger::wrap(0);
626 for (auto tensorType : inputTypes)
627 concatSize =
628 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
629 sizes[dim] = concatSize.asInteger();
630 auto inferredResultType =
631 RankedTensorType::get(sizes, inputTypes[0].getElementType());
632
633 for (auto [inferredSize, actualSize] :
634 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
635 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
636 ShapedType::isDynamic(actualSize);
637 if (!hasDynamic && inferredSize != actualSize)
638 return emitOpError("result type ")
639 << resultType << " does not match inferred shape "
640 << inferredResultType << " static sizes";
641 }
642
643 return success();
644}
645
646FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
647 size_t numInputs = getInputs().size();
648 uint64_t concatDim = getDim();
649
650 SmallVector<SmallVector<OpFoldResult>> inputShapes;
651 inputShapes.reserve(numInputs);
652 SmallVector<OpFoldResult> concatOffsets;
653 concatOffsets.reserve(numInputs);
654 SmallVector<OpFoldResult> outputShape;
655
656 AffineExpr addExpr =
657 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
658 OpFoldResult zero = builder.getIndexAttr(0);
659 Location loc = getLoc();
660 for (auto [index, input] : llvm::enumerate(getInputs())) {
661 SmallVector<OpFoldResult> inputShape =
662 tensor::getMixedSizes(builder, input.getLoc(), input);
663 if (index == 0) {
664 outputShape = inputShape;
665 concatOffsets.push_back(zero);
666 } else {
667 concatOffsets.push_back(outputShape[concatDim]);
668 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
669 builder, loc, addExpr,
670 {outputShape[concatDim], inputShape[concatDim]});
671 }
672 inputShapes.emplace_back(std::move(inputShape));
673 }
674
675 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
676 getType().getElementType());
677
678 int64_t rank = getType().getRank();
679 OpFoldResult one = builder.getIndexAttr(1);
680 SmallVector<OpFoldResult> strides(rank, one);
681 SmallVector<OpFoldResult> offsets(rank, zero);
682 for (auto [index, input] : llvm::enumerate(getInputs())) {
683 offsets[concatDim] = concatOffsets[index];
684 auto insertSlice = tensor::InsertSliceOp::create(
685 builder, loc, input, replacement, offsets, inputShapes[index], strides);
686 replacement = insertSlice.getResult();
687 }
688 if (replacement.getType() != getType()) {
689 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
690 }
691 return SmallVector<Value>{replacement};
692}
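// Decomposition sketch (illustrative; %a and %b are hypothetical inputs):
// concatenating tensor<2x4xf32> and tensor<3x4xf32> along dim 0 becomes
//   %e  = tensor.empty() : tensor<5x4xf32>
//   %i0 = tensor.insert_slice %a into %e[0, 0] [2, 4] [1, 1]
//   %i1 = tensor.insert_slice %b into %i0[2, 0] [3, 4] [1, 1]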
693
694LogicalResult
695ConcatOp::reifyResultShapes(OpBuilder &builder,
696 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
697 ValueRange inputs = getInputs();
698 int64_t dim = getDim();
699 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
700
701 Value init = inputs[0];
702 int64_t rank = getType().getRank();
703
704 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
705
706 // Pre-populate the result sizes with as much static information as possible
707 // from the given result type, as well as the inferred result type, otherwise
708 // use the dim sizes from the first input.
709 for (int64_t i = 0; i < rank; ++i) {
710 if (i == dim)
711 continue;
712 if (!getType().isDynamicDim(i)) {
713 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
714 } else if (!inferredResultType.isDynamicDim(i)) {
715 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
716 builder, getLoc(),
717 builder.getIndexAttr(inferredResultType.getDimSize(i)));
718 } else {
719 reifiedReturnShapes[0][i] =
720 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
721 }
722 }
723
724 if (getType().isDynamicDim(dim)) {
725 // Take the sum of the input sizes along the concatenated dim.
726 AffineExpr sum = builder.getAffineDimExpr(0);
727 SmallVector<OpFoldResult> sizes = {
728 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
729 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
730 sum = sum + builder.getAffineDimExpr(idx + 1);
731 sizes.push_back(
732 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
733 }
734 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
735 builder, getLoc(),
736 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
737 } else {
738 // If the result shape is static along the concatenated dim, use the static
739 // shape.
740 reifiedReturnShapes[0][dim] =
741 builder.getIndexAttr(getType().getDimSize(dim));
742 }
743 return success();
744}
745
746void ConcatOp::getAsmResultNames(
747 function_ref<void(Value, StringRef)> setNameFn) {
748 setNameFn(getResult(), "concat");
749}
750
751OpFoldResult ConcatOp::fold(FoldAdaptor) {
752 ValueRange inputs = getInputs();
753 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
754 return inputs[0];
755 return {};
756}
757
758namespace {
759/// Fold a concat op with a single input to a cast.
760struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
761 using OpRewritePattern<ConcatOp>::OpRewritePattern;
762
763 LogicalResult matchAndRewrite(ConcatOp concatOp,
764 PatternRewriter &rewriter) const override {
765 if (concatOp.getInputs().size() != 1)
766 return failure();
767 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
768 concatOp.getInputs()[0]);
769 return success();
770 }
771};
772
773/// Propagate static shapes into the operands of a `tensor.concat`.
774///
775/// `tensor.concat` requires every operand to match on all dimensions except the
776/// concatenation dimension. If one operand is already static in those
777/// dimensions, the other operands may safely be refined to that same static
778/// shape.
779///
780/// Example:
781///
782/// ```mlir
783/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
784/// tensor<?x12xi32>
785/// ```
786/// ->
787/// ```mlir
788/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
789/// %2 = tensor.concat dim(0) %0, %cast :
790/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
791/// ```
792struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
793 using OpRewritePattern<ConcatOp>::OpRewritePattern;
794
795 LogicalResult matchAndRewrite(ConcatOp concatOp,
796 PatternRewriter &rewriter) const override {
797 int64_t dim = concatOp.getDim();
798 RankedTensorType inferredResultType =
799 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
800
801 // Find operands for which a more static shape can be inferred.
802 LogicalResult matched = failure();
803 // Inferred operand shapes are identical in every dimension except the
804 // concatenation dimension.
805 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
806 for (auto [operandIdx, operandType] :
807 llvm::enumerate(concatOp->getOperandTypes())) {
808 // Compute inferred type for operand.
809 inferredOperandShape[dim] =
810 cast<RankedTensorType>(operandType).getDimSize(dim);
811 auto inferredOperandType = RankedTensorType::get(
812 inferredOperandShape, inferredResultType.getElementType());
813
814 // Check if inferred type is more static.
815 if (!preservesStaticInformation(inferredOperandType, operandType)) {
816 matched = success();
817
818 // Use refined operand type and create cast from original operand.
819 auto castOp =
820 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
821 concatOp.getOperand(operandIdx));
822 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
823 concatOp->setOperand(operandIdx, castOp->getResult(0));
824 });
825 }
826 }
827
828 return matched;
829 }
830};
831
832/// Ensure `tensor.concat`'s result type is at least as static as can be inferred
833/// from its operand types.
834///
835/// Example:
836/// ```mlir
837/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
838/// tensor<?x?xi32>
839/// ```
840/// ->
841/// ```mlir
842/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
843///        -> tensor<?x12xi32>
844/// %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
845/// ```
846struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
847 using OpRewritePattern<ConcatOp>::OpRewritePattern;
848
849 LogicalResult matchAndRewrite(ConcatOp concatOp,
850 PatternRewriter &rewriter) const override {
851 int64_t dim = concatOp.getDim();
852 RankedTensorType inferredResultType =
853 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
854
855 // The result type should be at least as static as inferred result type.
856 if (preservesStaticInformation(inferredResultType,
857 concatOp.getResultType())) {
858 return failure();
859 }
860
861 auto newConcatOp =
862 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
863 concatOp->getOperands());
864 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
865 newConcatOp);
866
867 return success();
868 }
869};
870} // namespace
871
872void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
873 MLIRContext *context) {
874 results
875 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
876 context);
877}
878
879//===----------------------------------------------------------------------===//
880// DimOp
881//===----------------------------------------------------------------------===//
882
883void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
884 setNameFn(getResult(), "dim");
885}
886
887void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
888 int64_t index) {
889 auto loc = result.location;
890 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
891 build(builder, result, source, indexValue);
892}
893
894std::optional<int64_t> DimOp::getConstantIndex() {
895 return getConstantIntValue(getIndex());
896}
897
898Speculation::Speculatability DimOp::getSpeculatability() {
899 auto constantIndex = getConstantIndex();
900 if (!constantIndex)
901 return Speculation::NotSpeculatable;
902
903 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
904 if (!rankedSourceType)
905 return Speculation::NotSpeculatable;
906
907 if (rankedSourceType.getRank() <= constantIndex)
908 return Speculation::NotSpeculatable;
909
910 return Speculation::Speculatable;
911}
912
913void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
914 SetIntLatticeFn setResultRange) {
915 setResultRange(getResult(),
916 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
917}
918
919OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
920 // All forms of folding require a known index.
921 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
922 if (!index)
923 return {};
924
925 // Folding for unranked types (UnrankedTensorType) is not supported.
926 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
927 if (!tensorType)
928 return {};
929
930 // Out of bound indices produce undefined behavior but are still valid IR.
931 // Don't choke on them.
932 int64_t indexVal = index.getInt();
933 if (indexVal < 0 || indexVal >= tensorType.getRank())
934 return {};
935
936 // Fold if the shape extent along the given index is known.
937 if (!tensorType.isDynamicDim(index.getInt())) {
938 Builder builder(getContext());
939 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
940 }
941
942 Operation *definingOp = getSource().getDefiningOp();
943
944 // Fold dim to the operand of tensor.generate.
945 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
946 auto resultType =
947 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
948 // The case where the type encodes the size of the dimension is handled
949 // above.
950 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
951
952 // Find the operand of the fromElements that corresponds to this index.
953 auto dynExtents = fromElements.getDynamicExtents().begin();
954 for (auto dim : resultType.getShape().take_front(index.getInt()))
955 if (ShapedType::isDynamic(dim))
956 dynExtents++;
957
958 return Value{*dynExtents};
959 }
960
961 // The size at the given index is now known to be a dynamic size.
962 unsigned unsignedIndex = index.getValue().getZExtValue();
963
964 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
965 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
966 // `resolve-shaped-type-result-dims` pass.
967 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
968 sliceOp.isDynamicSize(unsignedIndex)) {
969 return {sliceOp.getDynamicSize(unsignedIndex)};
970 }
971 }
972
973 // dim(cast) -> dim
974 if (succeeded(foldTensorCast(*this)))
975 return getResult();
976
977 return {};
978}
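// Illustrative folds (hypothetical SSA names): tensor.dim %t, %c0 with
// %t : tensor<4x?xf32> folds to the index constant 4, and tensor.dim of a
// tensor.generate result at a dynamic dim folds to the matching dynamic
// extent operand of the generate op.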
979
980namespace {
981/// Fold dim of a cast into the dim of the source of the tensor cast.
982struct DimOfCastOp : public OpRewritePattern<DimOp> {
983 using OpRewritePattern<DimOp>::OpRewritePattern;
984
985 LogicalResult matchAndRewrite(DimOp dimOp,
986 PatternRewriter &rewriter) const override {
987 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
988 if (!castOp)
989 return failure();
990 Value newSource = castOp.getOperand();
991 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
992 return success();
993 }
994};
995
996/// Fold dim of a destination passing style op into the dim of the corresponding
997/// init.
998struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
999 using OpRewritePattern<DimOp>::OpRewritePattern;
1000
1001 LogicalResult matchAndRewrite(DimOp dimOp,
1002 PatternRewriter &rewriter) const override {
1003 auto source = dimOp.getSource();
1004 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1005 if (!destOp)
1006 return failure();
1007
1008 auto resultIndex = cast<OpResult>(source).getResultNumber();
1009 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1010
1011 rewriter.modifyOpInPlace(
1012 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1013 return success();
1014 }
1015};
1016
1017/// Fold dim of a tensor reshape operation to an extract from the reshape's
1018/// shape operand.
1019struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1020 using OpRewritePattern<DimOp>::OpRewritePattern;
1021
1022 LogicalResult matchAndRewrite(DimOp dim,
1023 PatternRewriter &rewriter) const override {
1024 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1025
1026 if (!reshape)
1027 return failure();
1028
1029 // Since tensors are immutable we don't need to worry about where to place
1030 // the extract call
1031 rewriter.setInsertionPointAfter(dim);
1032 Location loc = dim.getLoc();
1033 Value extract =
1034 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1035 if (extract.getType() != dim.getType())
1036 extract =
1037 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1038 rewriter.replaceOp(dim, extract);
1039 return success();
1040 }
1041};
1042} // namespace
1043
1044void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1045 MLIRContext *context) {
1046 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1047}
1048
1049//===----------------------------------------------------------------------===//
1050// EmptyOp
1051//===----------------------------------------------------------------------===//
1052
1053void EmptyOp::build(OpBuilder &builder, OperationState &result,
1054 ArrayRef<int64_t> staticShape, Type elementType,
1055 Attribute encoding) {
1056 assert(none_of(staticShape, ShapedType::isDynamic) &&
1057 "expected only static sizes");
1058 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1059}
1060
1061void EmptyOp::build(OpBuilder &builder, OperationState &result,
1062 ArrayRef<int64_t> staticShape, Type elementType,
1063 ValueRange dynamicSizes, Attribute encoding) {
1064 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1065 build(builder, result, tensorType, dynamicSizes);
1066}
1067
1068void EmptyOp::build(OpBuilder &builder, OperationState &result,
1069 ArrayRef<OpFoldResult> sizes, Type elementType,
1070 Attribute encoding) {
1071 SmallVector<int64_t> staticShape;
1072 SmallVector<Value> dynamicSizes;
1073 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1074 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1075}
1076
1077LogicalResult EmptyOp::verify() {
1078 if (getType().getNumDynamicDims() != getDynamicSizes().size())
1079 return emitOpError("incorrect number of dynamic sizes, has ")
1080 << getDynamicSizes().size() << ", expected "
1081 << getType().getNumDynamicDims();
1082 return success();
1083}
1084
1085LogicalResult
1086EmptyOp::reifyResultShapes(OpBuilder &builder,
1087 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1088 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1089 unsigned ctr = 0;
1090 for (int64_t i = 0; i < getType().getRank(); ++i) {
1091 if (getType().isDynamicDim(i)) {
1092 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1093 } else {
1094 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1095 }
1096 }
1097 return success();
1098}
1099
1100Value EmptyOp::getDynamicSize(unsigned idx) {
1101 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1102 unsigned ctr = 0;
1103 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1104 if (getType().isDynamicDim(i))
1105 ++ctr;
1106 return getDynamicSizes()[ctr];
1107}
1108
1109SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1110 SmallVector<OpFoldResult> result;
1111 unsigned ctr = 0;
1112 OpBuilder b(getContext());
1113 for (int64_t i = 0; i < getType().getRank(); ++i) {
1114 if (getType().isDynamicDim(i)) {
1115 result.push_back(getDynamicSizes()[ctr++]);
1116 } else {
1117 result.push_back(b.getIndexAttr(getType().getShape()[i]));
1118 }
1119 }
1120 return result;
1121}
1122
1123namespace {
1124/// Change the type of the result of a `tensor.empty` by making the result
1125/// type statically sized along dimensions that in the original operation were
1126/// defined as dynamic, but the size was defined using a `constant` op. For
1127/// example
1128///
1129/// %c5 = arith.constant 5: index
1130/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1131///
1132/// to
1133///
1134/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1135struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1136 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1137
1138 LogicalResult matchAndRewrite(EmptyOp op,
1139 PatternRewriter &rewriter) const override {
1140 SmallVector<Value> foldedDynamicSizes;
1141 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1142 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1143
1144 // Stop here if no dynamic size was promoted to static.
1145 if (foldedTensorType == op.getType())
1146 return failure();
1147
1148 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1149 foldedDynamicSizes);
1150 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1151 return success();
1152 }
1153};
1154
1155struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1156 using OpRewritePattern<DimOp>::OpRewritePattern;
1157
1158 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1159 PatternRewriter &rewriter) const override {
1160 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1161 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1162 if (!emptyTensorOp || !maybeConstantIndex)
1163 return failure();
1164 auto emptyTensorType = emptyTensorOp.getType();
1165 if (*maybeConstantIndex < 0 ||
1166 *maybeConstantIndex >= emptyTensorType.getRank() ||
1167 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1168 return failure();
1169 rewriter.replaceOp(dimOp,
1170 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1171 return success();
1172 }
1173};
1174
1175/// Canonicalize
1176///
1177/// ```mlir
1178/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1179/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1180/// ```
1181///
1182/// into
1183///
1184/// ```mlir
1185/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1186/// ```
1187///
1188/// This assumes the input program is correct in terms of its shape. So it is
1189/// safe to assume that `%d0` is in fact 4.
1190struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1191 using OpRewritePattern<CastOp>::OpRewritePattern;
1192
1193 LogicalResult matchAndRewrite(CastOp castOp,
1194 PatternRewriter &rewriter) const override {
1195 if (!canFoldIntoProducerOp(castOp))
1196 return failure();
1197 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1198 if (!producer)
1199 return failure();
1200
1201 auto resultType =
1202 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1203 ArrayRef<int64_t> resultShape = resultType.getShape();
1204 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1205 SmallVector<OpFoldResult> newMixedSizes;
1206 newMixedSizes.reserve(currMixedSizes.size());
1207 assert(resultShape.size() == currMixedSizes.size() &&
1208 "mismatch in result shape and sizes of empty op");
1209 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1210 int64_t newDim = std::get<0>(it);
1211 OpFoldResult currDim = std::get<1>(it);
1212 // Case 1: The empty tensor dim is static. Check that the tensor cast
1213 // result dim matches.
1214 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1215 if (ShapedType::isDynamic(newDim) ||
1216 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1217 // Something is off, the cast result shape cannot be more dynamic
1218 // than the empty tensor result shape (enforced by
1219 // `canFoldIntoProducer`). Abort for now.
1220 return rewriter.notifyMatchFailure(
1221 producer, "mismatch in static value of shape of empty tensor "
1222 "result and cast result");
1223 }
1224 newMixedSizes.push_back(attr);
1225 continue;
1226 }
1227
1228 // Case 2 : The tensor cast shape is static, but empty tensor result
1229 // shape is dynamic.
1230 if (ShapedType::isStatic(newDim)) {
1231 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1232 continue;
1233 }
1234
1235 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1236 // shape is dynamic. Use the dynamic value from the empty tensor op.
1237 newMixedSizes.push_back(currDim);
1238 }
1239
1240 // TODO: Do not drop tensor encoding.
1241 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1242 resultType.getElementType());
1243 return success();
1244 }
1245};
1246
1247} // namespace
1248
1249void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1250 MLIRContext *context) {
1251 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1252 ReplaceEmptyTensorStaticShapeDims>(context);
1253}
1254
1255//===----------------------------------------------------------------------===//
1256// ExtractOp
1257//===----------------------------------------------------------------------===//
1258
1259namespace {
1260
1261/// Canonicalizes the pattern of the form
1262///
1263/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1264/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1265///
1266/// to
1267///
1268/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1269struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1270 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1271
1272 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1273 PatternRewriter &rewriter) const final {
1274 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1275 if (!tensorCast)
1276 return failure();
1277 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1278 return failure();
1279 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1280 extract, tensorCast.getSource(), extract.getIndices());
1281 return success();
1282 }
1283};
1284
1285/// Canonicalizes the pattern of the form
1286///
1287/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1288/// tensor<12xf64>
1289/// %extracted_element = tensor.extract %val[%c10] :
1290/// tensor<12xf64>
1291///
1292/// to
1293///
1294/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1295struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1296 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1297
1298 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1299 PatternRewriter &rewriter) const final {
1300 auto collapseOp =
1301 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1302 if (!collapseOp)
1303 return failure();
1304 if (!collapseOp.getSrcType().hasStaticShape())
1305 return failure();
1306
1307 auto sourceSizes = collapseOp.getSrcType().getShape();
1308
1309 SmallVector<Value> indices(extractOp.getIndices().begin(),
1310 extractOp.getIndices().end());
1311 SmallVector<Value> sourceIndices;
1312 for (auto [index, group] :
1313 llvm::zip(indices, collapseOp.getReassociationIndices())) {
1314 assert(!group.empty() && "association indices groups cannot be empty");
1315 auto groupSize = group.size();
1316
1317 if (groupSize == 1) {
1318 sourceIndices.push_back(index);
1319 continue;
1320 }
1321
1322 SmallVector<int64_t> basis =
1323 llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1324 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1325 rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1326 llvm::append_range(sourceIndices, delinearize.getResults());
1327 }
1328 if (collapseOp.getReassociationIndices().empty()) {
1329 auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1330 int64_t srcRank =
1331 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1332 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1333 rewriter, extractOp.getLoc(), zeroAffineMap,
1334 ArrayRef<OpFoldResult>{});
1335 for (int64_t i = 0; i < srcRank; i++) {
1336 sourceIndices.push_back(
1337 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1338 }
1339 }
1340
1341 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1342 extractOp, collapseOp.getSrc(), sourceIndices);
1343 return success();
1344 }
1345};
1346
1347} // namespace
1348
1349void ExtractOp::getAsmResultNames(
1350 function_ref<void(Value, StringRef)> setNameFn) {
1351 setNameFn(getResult(), "extracted");
1352}
1353
1354LogicalResult ExtractOp::verify() {
1355 // Verify the # indices match if we have a ranked type.
1356 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1357 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1358 return emitOpError("incorrect number of indices for extract_element");
1359 return success();
1360}
1361
1362/// If we have an ExtractOp consuming an InsertOp with the same
1363/// indices, we can return the InsertOp's scalar directly.
1364// TODO: This only checks the immediate producer; extend to go up the
1365// insert/extract chain if the slices are disjoint.
1366static Value foldExtractAfterInsert(ExtractOp extractOp) {
1367 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1368
1369 auto isSame = [](Value a, Value b) {
1370 return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1371 };
1372 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1373 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1374 return insertOp.getScalar();
1375
1376 return {};
1377}
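// For example, given %t = tensor.insert %v into %d[%i, %j], a subsequent
// tensor.extract %t[%i, %j] folds directly to %v.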
1378
1379OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1380 if (Attribute tensor = adaptor.getTensor()) {
1381 // If this is a splat elements attribute, simply return the value.
1382 // All of the elements of a splat attribute are the same.
1383 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1384 return splatTensor.getSplatValue<Attribute>();
1385
1386 // If this is a dense resource elements attribute, return.
1387 if (isa<DenseResourceElementsAttr>(tensor))
1388 return {};
1389 }
1390
1391 // Collect the constant indices into the tensor.
1392 SmallVector<uint64_t, 8> indices;
1393 for (Attribute indice : adaptor.getIndices()) {
1394 if (!indice || !llvm::isa<IntegerAttr>(indice))
1395 return {};
1396 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1397 }
1398
1399 // Fold extract(from_elements(...)).
1400 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1401 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1402 auto rank = tensorType.getRank();
1403 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1404 "rank mismatch");
1405 int flatIndex = 0;
1406 int stride = 1;
1407 for (int i = rank - 1; i >= 0; --i) {
1408 flatIndex += indices[i] * stride;
1409 stride *= tensorType.getDimSize(i);
1410 }
1411 // Prevent out of bounds accesses. This can happen in invalid code that
1412 // will never execute.
1413 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1414 flatIndex < 0)
1415 return {};
1416 return fromElementsOp.getElements()[flatIndex];
1417 }
1418
1419 // If this is an elements attribute, query the value at the given indices.
1420 if (Attribute tensor = adaptor.getTensor()) {
1421 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1422 if (elementsAttr && elementsAttr.isValidIndex(indices))
1423 return elementsAttr.getValues<Attribute>()[indices];
1424 }
1425
1426 if (Value result = foldExtractAfterInsert(*this))
1427 return result;
1428
1429 return {};
1430}
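// Illustrative fold (hypothetical names): for
//   %t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
// tensor.extract %t[%c1, %c0] folds to %c, using the row-major flat index
// 1 * 2 + 0 = 2.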
1431
1432void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1433 MLIRContext *context) {
1434 results.add<ExtractFromTensorCast>(context);
1435}
1436
1437void mlir::tensor::populateFoldCollapseExtractPatterns(
1438 RewritePatternSet &patterns) {
1439 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1440}
1441
1442//===----------------------------------------------------------------------===//
1443// FromElementsOp
1444//===----------------------------------------------------------------------===//
1445
1446void FromElementsOp::getAsmResultNames(
1447 function_ref<void(Value, StringRef)> setNameFn) {
1448 setNameFn(getResult(), "from_elements");
1449}
1450
1451void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1452 ValueRange elements) {
1453 assert(!elements.empty() && "expected at least one element");
1454 Type resultType = RankedTensorType::get(
1455 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1456 build(builder, result, resultType, elements);
1457}
1458
1459OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1460 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1461 return DenseElementsAttr::get(getType(), adaptor.getElements());
1462 return {};
1463}
1464
1465namespace {
1466
1467// Pushes the index_casts that occur before extractions to after the extract.
1468// This minimizes type conversion in some cases and enables the extract
1469// canonicalizer. This changes:
1470//
1471// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1472// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1473//
1474// to the following:
1475//
1476// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1477// %cast = arith.index_cast %extract : i32 to index
1478//
1479// which the extract folder can then reduce to just the underlying element.
1480//
1481// Consider expanding this to a template and handling all tensor cast
1482// operations.
1483struct ExtractElementFromIndexCast
1484 : public OpRewritePattern<tensor::ExtractOp> {
1485 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1486
1487 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1488 PatternRewriter &rewriter) const final {
1489 Location loc = extract.getLoc();
1490 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1491 if (!indexCast)
1492 return failure();
1493
1494 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1495
1496 auto newExtract = tensor::ExtractOp::create(
1497 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1498
1499 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1500 newExtract);
1501
1502 return success();
1503 }
1504};
1505
1506} // namespace
1507
1508void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1509 MLIRContext *context) {
1510 results.add<ExtractElementFromIndexCast>(context);
1511}
1512
1513//===----------------------------------------------------------------------===//
1514// GatherOp
1515//===----------------------------------------------------------------------===//
1516
1517void GatherOp::getAsmResultNames(
1518 function_ref<void(Value, StringRef)> setNameFn) {
1519 setNameFn(getResult(), "gather");
1520}
1521
1522/// Return the inferred result type for a gatherOp where:
1523/// - sourceType is the type of the source tensor gathered from
1524/// - indicesType is the type of the indices used to gather
1525/// - gatherDims are the dims along which the gather occurs.
1526/// Return a full-rank or rank-reduced variant of the type depending on
1527/// the value of rankReduced.
1528///
1529/// The leading dimensions of the index tensor give the result tensor its
1530/// leading dimensions.
1531/// The trailing dimensions of the result tensor are obtained from the source
1532/// tensor by setting the dimensions specified in gather_dims to `1` (if
1533/// rankReduced is false), or skipping them (otherwise).
1534RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1535 RankedTensorType indicesType,
1536 ArrayRef<int64_t> gatherDims,
1537 bool rankReduced) {
1538 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1539 resultShape.reserve(resultShape.size() + sourceType.getRank());
1540 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1541 if (llvm::binary_search(gatherDims, idx)) {
1542 if (!rankReduced)
1543 resultShape.push_back(1);
1544 continue;
1545 }
1546 resultShape.push_back(sourceType.getDimSize(idx));
1547 }
1548 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1549}
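// For example, with source tensor<4x5x6xf32>, indices tensor<2x3x1xindex>,
// and gather_dims = [1], the inferred result is tensor<2x3x4x1x6xf32>, or
// tensor<2x3x4x6xf32> in the rank-reduced variant.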
1550
1551static LogicalResult
1552verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1553 ArrayRef<int64_t> indices, int64_t rank,
1554 StringRef gatherOrScatter, StringRef sourceOrDest) {
1555 if (dims.empty())
1556 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1557
1558 int64_t numGatherDims = dims.size();
1559 if (numGatherDims > rank)
1560 return op->emitOpError(gatherOrScatter)
1561 << "_dims overflow " << sourceOrDest << " rank";
1562 if (indices.empty() || indices.back() != numGatherDims)
1563 return op->emitOpError(gatherOrScatter)
1564 << "_dims length must match the size of last dimension of indices";
1565 for (int64_t val : dims) {
1566 if (val < 0)
1567 return op->emitOpError(gatherOrScatter)
1568 << "_dims value must be non-negative";
1569 if (val >= rank)
1570 return op->emitOpError(gatherOrScatter)
1571 << "_dims value must be smaller than " << sourceOrDest << " rank";
1572 }
1573 for (int64_t i = 1; i < numGatherDims; ++i) {
1574 if (dims[i - 1] >= dims[i])
1575 return op->emitOpError(gatherOrScatter)
1576 << "_dims values must be strictly increasing";
1577 }
1578 return success();
1579}
1580
1581LogicalResult GatherOp::verify() {
1582 int64_t sourceRank = getSourceType().getRank();
1583 ArrayRef<int64_t> gatherDims = getGatherDims();
1584 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1585 getIndicesType().getShape(), sourceRank,
1586 "gather", "source")))
1587 return failure();
1588
1589 RankedTensorType expectedResultType = GatherOp::inferResultType(
1590 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1591 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1592 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1593 if (getResultType() != expectedResultType &&
1594 getResultType() != expectedRankReducedResultType) {
1595 return emitOpError("result type "
1596 "mismatch: "
1597 "expected ")
1598 << expectedResultType << " or its rank-reduced variant "
1599 << expectedRankReducedResultType << " (got: " << getResultType()
1600 << ")";
1601 }
1602
1603 return success();
1604}
1605
1606OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1607 if (OpFoldResult reshapedSource = reshapeConstantSource(
1608 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1609 getResult().getType()))
1610 return reshapedSource;
1611 return {};
1612}
1613
1614//===----------------------------------------------------------------------===//
1615// InsertOp
1616//===----------------------------------------------------------------------===//
1617
1618void InsertOp::getAsmResultNames(
1619 function_ref<void(Value, StringRef)> setNameFn) {
1620 setNameFn(getResult(), "inserted");
1621}
1622
1623LogicalResult InsertOp::verify() {
1624 // Verify the # indices match if we have a ranked type.
1625 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1626 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1627 return emitOpError("incorrect number of indices");
1628 return success();
1629}
1630
1631OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1632 Attribute scalar = adaptor.getScalar();
1633 Attribute dest = adaptor.getDest();
1634 if (scalar && dest)
1635 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1636 if (scalar == splatDest.getSplatValue<Attribute>())
1637 return dest;
1638 return {};
1639}
1640
1641//===----------------------------------------------------------------------===//
1642// GenerateOp
1643//===----------------------------------------------------------------------===//
1644
1645void GenerateOp::getAsmResultNames(
1646 function_ref<void(Value, StringRef)> setNameFn) {
1647 setNameFn(getResult(), "generated");
1648}
1649
1650LogicalResult GenerateOp::reifyResultShapes(
1651 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1652 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1653 int idx = 0;
1654 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1655 if (getType().isDynamicDim(dim)) {
1656 reifiedReturnShapes[0][dim] = getOperand(idx++);
1657 } else {
1658 reifiedReturnShapes[0][dim] =
1659 builder.getIndexAttr(getType().getDimSize(dim));
1660 }
1661 }
1662 return success();
1663}
1664
1665LogicalResult GenerateOp::verify() {
1666 // Ensure that the tensor type has as many dynamic dimensions as are
1667 // specified by the operands.
1668 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1669 if (getNumOperands() != resultType.getNumDynamicDims())
1670 return emitError("must have as many index operands as dynamic extents "
1671 "in the result type");
1672 return success();
1673}
1674
1675LogicalResult GenerateOp::verifyRegions() {
1676 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1677 // Ensure that region arguments span the index space.
1678 if (!llvm::all_of(getBody().getArgumentTypes(),
1679 [](Type ty) { return ty.isIndex(); }))
1680 return emitError("all body arguments must be index");
1681 if (getBody().getNumArguments() != resultTy.getRank())
1682 return emitError("must have one body argument per input dimension");
1683
1684 // Ensure that the region yields an element of the right type.
1685 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1686
1687 if (yieldOp.getValue().getType() != resultTy.getElementType())
1688 return emitOpError(
1689 "body must be terminated with a `yield` operation of the tensor "
1690 "element type");
1691
1692 return success();
1693}
1694
1695void GenerateOp::build(
1696 OpBuilder &b, OperationState &result, Type resultTy,
1697 ValueRange dynamicExtents,
1698 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1699 build(b, result, resultTy, dynamicExtents);
1700
1701 // Build and populate body.
1702 OpBuilder::InsertionGuard guard(b);
1703 Region *bodyRegion = result.regions.front().get();
1704 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1705 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1706 SmallVector<Location, 2> argumentLocs(rank, result.location);
1707 Block *bodyBlock =
1708 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1709 bodyBuilder(b, result.location, bodyBlock->getArguments());
1710}
1711
1712namespace {
1713
1714/// Canonicalizes tensor.generate operations with constant extent operands
1715/// into the equivalent operation with those extents reflected in the result
1716/// type instead. A type cast is also inserted so that the resulting IR
1717/// remains well-typed.
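///
/// For illustration (hypothetical IR; %c8 is assumed to be an
/// arith.constant 8 : index):
///
/// %0 = tensor.generate %c8 {
/// ^bb0(%i : index):
///   <computation>
///   tensor.yield %v : f32
/// } : tensor<?xf32>
///
/// is rewritten to a tensor.generate with result type tensor<8xf32> (and no
/// dynamic extent operand), followed by a tensor.cast back to tensor<?xf32>.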
1718struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1719 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1720
1721 LogicalResult matchAndRewrite(GenerateOp generateOp,
1722 PatternRewriter &rewriter) const final {
1723 SmallVector<Value> foldedDynamicSizes;
1724 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1725 generateOp.getType(), generateOp.getDynamicExtents(),
1726 foldedDynamicSizes);
1727
1728 // Stop here if no dynamic size was promoted to static.
1729 if (foldedTensorType == generateOp.getType())
1730 return failure();
1731
1732 auto loc = generateOp.getLoc();
1733 auto newOp =
1734 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1735 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1736 newOp.getBody().begin());
1737 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1738 generateOp.getType(), newOp);
1739 return success();
1740 }
1741};
1742
1743/// Canonicalizes the pattern of the form
1744///
1745/// %tensor = tensor.generate %x {
1746/// ^bb0(%arg0: index):
1747/// <computation>
1748/// yield %1 : index
1749/// } : tensor<?xindex>
1750/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1751///
1752/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1753/// tensor.generate operation has no side-effects.
1754struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1755 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1756
1757 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1758 PatternRewriter &rewriter) const final {
1759 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1760 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1761 return failure();
1762
1763 IRMapping mapping;
1764 Block *body = &tensorFromElements.getBody().front();
1765 mapping.map(body->getArguments(), extract.getIndices());
1766 for (auto &op : body->without_terminator())
1767 rewriter.clone(op, mapping);
1768
1769 auto yield = cast<YieldOp>(body->getTerminator());
1770
1771 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1772 return success();
1773 }
1774};
1775
1776} // namespace
1777
1778void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1779 MLIRContext *context) {
1780 // TODO: Move extract pattern to tensor::ExtractOp.
1781 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1782}
1783
1784//===----------------------------------------------------------------------===//
1785// RankOp
1786//===----------------------------------------------------------------------===//
1787
1788void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1789 setNameFn(getResult(), "rank");
1790}
1791
1792OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1793 // Constant fold rank when the rank of the operand is known.
1794 auto type = getOperand().getType();
1795 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1796 if (shapedType && shapedType.hasRank())
1797 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1798 return IntegerAttr();
1799}
1800
1801//===----------------------------------------------------------------------===//
1802// ReshapeOp
1803//===----------------------------------------------------------------------===//
1804
1805void ReshapeOp::getAsmResultNames(
1806 function_ref<void(Value, StringRef)> setNameFn) {
1807 setNameFn(getResult(), "reshape");
1808}
1809
1810static int64_t getNumElements(ShapedType type) {
1811 int64_t numElements = 1;
1812 for (auto dim : type.getShape())
1813 numElements *= dim;
1814 return numElements;
1815}
1816
1817LogicalResult ReshapeOp::verify() {
1818 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1819 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1820
1821 if (operandType.getElementType() != resultType.getElementType())
1822 return emitOpError("element types of source and destination tensor "
1823 "types should be the same");
1824
1825 int64_t shapeSize =
1826 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1827 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1828 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1829
1830 if (resultRankedType) {
1831 if (operandRankedType && resultRankedType.hasStaticShape() &&
1832 operandRankedType.hasStaticShape()) {
1833 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1834 return emitOpError("source and destination tensor should have the "
1835 "same number of elements");
1836 }
1837 if (ShapedType::isDynamic(shapeSize))
1838 return emitOpError("cannot use shape operand with dynamic length to "
1839 "reshape to statically-ranked tensor type");
1840 if (shapeSize != resultRankedType.getRank())
1841 return emitOpError(
1842 "length of shape operand differs from the result's tensor rank");
1843 }
1844 return success();
1845}
1846
1847OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1848 if (OpFoldResult reshapedSource = reshapeConstantSource(
1849 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1850 getResult().getType()))
1851 return reshapedSource;
1852
1853 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1854 // producer's input instead as the original tensor to reshape. This could
1855 // render the producer dead.
1856 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1857 getSourceMutable().assign(reshapeOpProducer.getSource());
1858 return getResult();
1859 }
1860
1861 auto source = getSource();
1862 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1863 auto resultTy = dyn_cast<RankedTensorType>(getType());
1864 if (!sourceTy || !resultTy || sourceTy != resultTy)
1865 return {};
1866
1867 // If the source and result are both 0D or 1D tensors and have the same type,
1868 // the reshape has no effect, even if the tensor is dynamically shaped.
1869 if (sourceTy.getRank() <= 1)
1870 return source;
1871
1872 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1873 auto elements = fromElements.getElements();
1874 bool dynamicNoop =
1875 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1876 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1877 auto element = elements[id];
1878
1879 if (auto cst = getConstantIntValue(element)) {
1880 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1881 continue;
1882 }
1883
1884 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1885 dynamicNoop &= dimOp.getSource() == source;
1886
1887 auto cst = getConstantIntValue(dimOp.getIndex());
1888 dynamicNoop &=
1889 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1890 continue;
1891 }
1892
1893 dynamicNoop = false;
1894 break;
1895 }
1896
1897 if (dynamicNoop)
1898 return source;
1899 }
1900
1901 return {};
1902}
1903
1904//===----------------------------------------------------------------------===//
1905// Reassociative reshape ops
1906//===----------------------------------------------------------------------===//
1907
1908void CollapseShapeOp::getAsmResultNames(
1909 function_ref<void(Value, StringRef)> setNameFn) {
1910 setNameFn(getResult(), "collapsed");
1911}
1912
1913void ExpandShapeOp::getAsmResultNames(
1914 function_ref<void(Value, StringRef)> setNameFn) {
1915 setNameFn(getResult(), "expanded");
1916}
1917
1918int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1919 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1920 "invalid resultDim");
1921 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1922 if (llvm::is_contained(it.value(), resultDim))
1923 return it.index();
1924 llvm_unreachable("could not find reassociation group");
1925}
1926
1927FailureOr<SmallVector<OpFoldResult>>
1928ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1929 RankedTensorType expandedType,
1930 ArrayRef<ReassociationIndices> reassociation,
1931 ArrayRef<OpFoldResult> inputShape) {
1932 std::optional<SmallVector<OpFoldResult>> outputShape =
1933 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1934 inputShape);
1935 if (!outputShape)
1936 return failure();
1937 return *outputShape;
1938}
1939
1940SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1941 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1942}
1943
1944void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1945 Type resultType, Value src,
1946 ArrayRef<ReassociationIndices> reassociation,
1947 ArrayRef<OpFoldResult> outputShape) {
1948 auto [staticOutputShape, dynamicOutputShape] =
1949 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1950 build(builder, result, cast<RankedTensorType>(resultType), src,
1951 getReassociationIndicesAttribute(builder, reassociation),
1952 dynamicOutputShape, staticOutputShape);
1953}
1954
1955void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1956 Type resultType, Value src,
1957 ArrayRef<ReassociationIndices> reassociation) {
1958 SmallVector<OpFoldResult> inputShape =
1959 getMixedSizes(builder, result.location, src);
1960 auto tensorResultTy = cast<RankedTensorType>(resultType);
1961 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1962 builder, result.location, tensorResultTy, reassociation, inputShape);
1963 SmallVector<OpFoldResult> outputShapeOrEmpty;
1964 if (succeeded(outputShape)) {
1965 outputShapeOrEmpty = *outputShape;
1966 }
1967 build(builder, result, tensorResultTy, src, reassociation,
1968 outputShapeOrEmpty);
1969}
1970
1971SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1972 return getSymbolLessAffineMaps(getReassociationExprs());
1973}
1974SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1975 return convertReassociationIndicesToExprs(getContext(),
1976 getReassociationIndices());
1977}
1978
1979SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1980 return getSymbolLessAffineMaps(getReassociationExprs());
1981}
1982SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1983 return convertReassociationIndicesToExprs(getContext(),
1984 getReassociationIndices());
1985}
1986
1987RankedTensorType CollapseShapeOp::inferCollapsedType(
1988 RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1989 return inferCollapsedType(
1990 type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1991 type.getContext(), reassociation)));
1992}
1993
1994/// Compute the RankedTensorType obtained by applying `reassociation` to
1995/// `type`.
1996RankedTensorType
1997CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1998 ArrayRef<AffineMap> reassociation) {
1999 auto shape = type.getShape();
2000 SmallVector<int64_t, 4> newShape;
2001 newShape.reserve(reassociation.size());
2002
2003 // Use the fact that reassociation is valid to simplify the logic: only use
2004 // each map's rank.
2005 assert(isReassociationValid(reassociation) && "invalid reassociation");
2006 unsigned currentDim = 0;
2007 for (AffineMap m : reassociation) {
2008 unsigned dim = m.getNumResults();
2009 auto band = shape.slice(currentDim, dim);
2010 int64_t size = 1;
2011 if (llvm::is_contained(band, ShapedType::kDynamic))
2012 size = ShapedType::kDynamic;
2013 else
2014 for (unsigned d = 0; d < dim; ++d)
2015 size *= shape[currentDim + d];
2016 newShape.push_back(size);
2017 currentDim += dim;
2018 }
2019
2020 return RankedTensorType::get(newShape, type.getElementType());
2021}
2022
2023void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2024 ArrayRef<ReassociationIndices> reassociation,
2025 ArrayRef<NamedAttribute> attrs) {
2026 auto resultType = inferCollapsedType(
2027 llvm::cast<RankedTensorType>(src.getType()),
2028 getSymbolLessAffineMaps(
2029 convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2030 result.addAttribute(getReassociationAttrStrName(),
2031 getReassociationIndicesAttribute(b, reassociation));
2032 build(b, result, resultType, src, attrs);
2033}
2034
2035template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2036 TensorReshapeOp, ExpandShapeOp>::value>
2037static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2038 RankedTensorType expandedType,
2039 RankedTensorType collapsedType) {
2040 if (failed(
2041 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2042 return failure();
2043
2044 auto maps = op.getReassociationMaps();
2045 RankedTensorType expectedType =
2046 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2047 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2048 return op.emitOpError("expected collapsed type to be ")
2049 << expectedType << ", but got " << collapsedType;
2050 return success();
2051}
2052
2053LogicalResult ExpandShapeOp::verify() {
2054 auto srcType = getSrcType();
2055 auto resultType = getResultType();
2056
2057 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2058 return emitOpError("expected number of static shape dims to be equal to "
2059 "the output rank (")
2060 << resultType.getRank() << ") but found "
2061 << getStaticOutputShape().size() << " inputs instead";
2062
2063 if ((int64_t)getOutputShape().size() !=
2064 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2065 return emitOpError("mismatch in dynamic dims in output_shape and "
2066 "static_output_shape: static_output_shape has ")
2067 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2068 << " dynamic dims while output_shape has " << getOutputShape().size()
2069 << " values";
2070
2071 return verifyTensorReshapeOp(*this, resultType, srcType);
2072}
2073
2074LogicalResult CollapseShapeOp::verify() {
2075 return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2076}
2077
2078namespace {
2079/// Reshape of a splat constant can be replaced with a constant of the result
2080/// type.
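///
/// For illustration (hypothetical IR):
///
/// %cst = arith.constant dense<1.0> : tensor<4x2xf32>
/// %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<4x2xf32> into tensor<8xf32>
///
/// folds to:
///
/// %0 = arith.constant dense<1.0> : tensor<8xf32>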
2081template <typename TensorReshapeOp>
2082struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2083 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2084 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2085 PatternRewriter &rewriter) const override {
2086 DenseElementsAttr attr;
2087 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2088 return failure();
2089 if (!attr || !attr.isSplat())
2090 return failure();
2091 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2092 reshapeOp.getResultType(), attr.getRawData());
2093 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2094 return success();
2095 }
2096};
2097
2098// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2099template <typename TensorReshapeOp>
2100class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2101public:
2102 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2103
2104 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2105 PatternRewriter &rewriter) const override {
2106 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2107 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2108 return failure();
2109
2110 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2111 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2112 return success();
2113 }
2114};
2115
2116/// Reshape of a FromElements can be replaced with a FromElements of the
2117/// result type
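///
/// For illustration (hypothetical IR):
///
/// %0 = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 2]
///     : tensor<4xf32> into tensor<2x2xf32>
///
/// folds to:
///
/// %1 = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>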
2118template <typename TensorReshapeOp>
2119struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2120 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2121 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2122 PatternRewriter &rewriter) const override {
2123 auto fromElements =
2124 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2125 if (!fromElements)
2126 return failure();
2127
2128 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2129
2130 if (!shapedTy.hasStaticShape())
2131 return failure();
2132
2133 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2134 fromElements.getElements());
2135 return success();
2136 }
2137};
2138
2139// Fold CastOp into CollapseShapeOp when adding static information.
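//
// For illustration (hypothetical IR):
//
//   %0 = tensor.cast %t : tensor<4x2xf32> to tensor<?x?xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
//
// is rewritten to a collapse of %t producing tensor<8xf32>, followed by a
// tensor.cast back to tensor<?xf32>.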
2140struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2141 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2142
2143 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2144 PatternRewriter &rewriter) const override {
2145 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2146 if (!tensor::canFoldIntoConsumerOp(castOp))
2147 return failure();
2148
2149 RankedTensorType srcType =
2150 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2151 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2152 srcType, collapseShapeOp.getReassociationMaps());
2153
2154 if (newResultType == collapseShapeOp.getResultType()) {
2155 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2156 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2157 });
2158 } else {
2159 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2160 newResultType, castOp.getSource(),
2161 collapseShapeOp.getReassociation());
2162 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2163 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2164 }
2165 return success();
2166 }
2167};
2168
2169/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2170/// matching constant output_shape operands of the expand. This makes the
2171/// `tensor.expand_shape` more static and creates a consumer cast that can be
2172/// propagated further.
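///
/// For illustration (hypothetical IR; %c2 is assumed to be an
/// arith.constant 2 : index):
///
///   %0 = tensor.cast %t : tensor<8xf32> to tensor<?xf32>
///   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%c2, 4]
///       : tensor<?xf32> into tensor<?x4xf32>
///
/// is rewritten to an expand of %t producing tensor<2x4xf32>, followed by a
/// tensor.cast back to tensor<?x4xf32>.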
2173struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2174 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2175
2176 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2177 PatternRewriter &rewriter) const override {
2178 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2179 if (!canFoldIntoConsumerOp(castOp))
2180 return failure();
2181
2182 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2183 SmallVector<ReassociationIndices, 4> reassoc =
2184 expandOp.getReassociationIndices();
2185
2186 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2187 SmallVector<Value> dynamicOutputShape;
2188 auto outputIt = expandOp.getOutputShape().begin();
2189
2190 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2191 for (uint64_t outDim : innerReassoc) {
2192 if (ShapedType::isStatic(newOutputShape[outDim]))
2193 continue;
2194
2195 // If the cast's src type is dynamic, don't infer any of the
2196 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2197 // least one of the expanded dimensions to be dynamic if the input is
2198 // dynamic.
2199 Value val = *outputIt;
2200 ++outputIt;
2201 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2202 dynamicOutputShape.push_back(val);
2203 continue;
2204 }
2205
2206 APInt cst;
2207 if (matchPattern(val, m_ConstantInt(&cst))) {
2208 newOutputShape[outDim] = cst.getSExtValue();
2209 } else {
2210 dynamicOutputShape.push_back(val);
2211 }
2212 }
2213 }
2214
2215 // Couldn't match any values, nothing to change
2216 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2217 return failure();
2218
2219 // Calculate the input shape from the output
2220 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2221 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2222 for (auto outDim : reassoc[inDim]) {
2223 auto ofr = newOutputShape[outDim];
2224 if (ShapedType::isDynamic(ofr)) {
2225 newInputShape[inDim] = ShapedType::kDynamic;
2226 break;
2227 }
2228 newInputShape[inDim] *= ofr;
2229 }
2230 }
2231
2232 SmallVector<OpFoldResult> outputOfr =
2233 getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2234 auto inputType = RankedTensorType::get(
2235 newInputShape, expandOp.getSrcType().getElementType());
2236 auto outputType = RankedTensorType::get(
2237 newOutputShape, expandOp.getSrcType().getElementType());
2238 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2239 expandOp.getSrc());
2240 auto newExpand = ExpandShapeOp::create(
2241 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2242 expandOp.getReassociationIndices(), outputOfr);
2243 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2244 newExpand.getResult());
2245 return success();
2246 }
2247};
2248} // namespace
2249
2250void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2251 MLIRContext *context) {
2252 results.add<
2253 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2254 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2255 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2256 FoldReshapeWithSplat<ExpandShapeOp>,
2257 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2258}
2259
2260void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2261 MLIRContext *context) {
2262 results.add<
2263 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2264 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2265 tensor::DimOp, RankedTensorType>,
2266 FoldReshapeWithConstant<CollapseShapeOp>,
2267 FoldReshapeWithSplat<CollapseShapeOp>,
2268 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2269 context);
2270}
2271
2272OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2273 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2274 adaptor.getOperands());
2275}
2276
2277OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2278 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2279 adaptor.getOperands());
2280}
2281
2282//===----------------------------------------------------------------------===//
2283// ExtractSliceOp
2284//===----------------------------------------------------------------------===//
2285
2286void ExtractSliceOp::getAsmResultNames(
2287 function_ref<void(Value, StringRef)> setNameFn) {
2288 setNameFn(getResult(), "extracted_slice");
2289}
2290
2291/// An extract_slice result type can be inferred, when it is not
2292/// rank-reduced, from the source type and the static representation of
2293/// offsets, sizes and strides. Special sentinels encode the dynamic case.
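///
/// For illustration (hypothetical shapes): with a source of type
/// tensor<8x16x4xf32> and static sizes [4, 4, 4], the inferred
/// (non-rank-reduced) result type is tensor<4x4x4xf32>; dynamic sizes are
/// encoded with the usual `?` sentinel.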
2294RankedTensorType ExtractSliceOp::inferResultType(
2295 RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2296 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2297 // An extract_slice op may specify only a leading subset of offsets/sizes/
2298 // strides, in which case we complete with offset=0, sizes from the source
2299 // tensor type, and strides=1.
2300 assert(static_cast<int64_t>(staticSizes.size()) ==
2301 sourceTensorType.getRank() &&
2302 "unexpected staticSizes not equal to rank of source");
2303 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2304 sourceTensorType.getEncoding());
2305}
2306
2307// TODO: This uses neither offsets nor strides!
2308RankedTensorType ExtractSliceOp::inferResultType(
2309 RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2310 ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2311 SmallVector<int64_t> staticSizes;
2312 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2313 assert(static_cast<int64_t>(staticSizes.size()) ==
2314 sourceTensorType.getRank() &&
2315 "unexpected staticSizes not equal to rank of source");
2316 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2317 sourceTensorType.getEncoding());
2318}
2319
2320/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2321/// number of sizes), drop as many size-1 dimensions as needed to produce an
2322/// inferred type with the desired rank.
2323///
2324/// Note that there may be multiple ways to compute this rank-reduced type:
2325/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2326///
2327/// To disambiguate, this function always drops the leading occurrences of size 1.
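///
/// For illustration: an inferred type of tensor<1x6x1xf32> with a desired
/// result rank of 2 drops the leading unit dimension and yields
/// tensor<6x1xf32>.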
2328RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2329 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2330 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2331 ArrayRef<int64_t> strides) {
2332 // Type inferred in the absence of rank-reducing behavior.
2333 auto inferredType = llvm::cast<RankedTensorType>(
2334 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2335 int rankDiff = inferredType.getRank() - desiredResultRank;
2336 if (rankDiff > 0) {
2337 auto shape = inferredType.getShape();
2338 llvm::SmallBitVector dimsToProject =
2339 getPositionsOfShapeOne(rankDiff, shape);
2340 SmallVector<int64_t> projectedShape;
2341 // Best effort rank-reducing: drop 1s in order.
2342 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2343 if (!dimsToProject.test(pos))
2344 projectedShape.push_back(shape[pos]);
2345 inferredType =
2346 RankedTensorType::get(projectedShape, inferredType.getElementType());
2347 }
2348 return inferredType;
2349}
2350
2351RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2352 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2353 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2354 ArrayRef<OpFoldResult> strides) {
2355 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2356 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2357 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2358 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2359 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2360 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2361 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2362 staticStrides);
2363}
2364
2365/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2366/// result type. If the type passed is nullptr, it is inferred.
2367void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2368 RankedTensorType resultType, Value source,
2369 ArrayRef<OpFoldResult> offsets,
2370 ArrayRef<OpFoldResult> sizes,
2371 ArrayRef<OpFoldResult> strides,
2372 ArrayRef<NamedAttribute> attrs) {
2373 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2374 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2375 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2376 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2377 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2378 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2379 // Structuring implementation this way avoids duplication between builders.
2380 if (!resultType) {
2381 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2382 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2383 }
2384 result.addAttributes(attrs);
2385 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2386 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2387 b.getDenseI64ArrayAttr(staticSizes),
2388 b.getDenseI64ArrayAttr(staticStrides));
2389}
2390
2391/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2392/// result type.
2393void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2394 ArrayRef<OpFoldResult> offsets,
2395 ArrayRef<OpFoldResult> sizes,
2396 ArrayRef<OpFoldResult> strides,
2397 ArrayRef<NamedAttribute> attrs) {
2398 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2399}
2400
2401/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2402/// a Range vector.
2403void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2404 ArrayRef<Range> ranges,
2405 ArrayRef<NamedAttribute> attrs) {
2406 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2407 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2408}
2409
2410/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2411/// the type passed is nullptr, it is inferred.
2412void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2413 RankedTensorType resultType, Value source,
2414 ValueRange offsets, ValueRange sizes,
2415 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2416 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2417 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2418 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2419 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2420 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2421 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2422 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2423}
2424
2425/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2426void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2427 ValueRange offsets, ValueRange sizes,
2428 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2429 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2430}
2431
2432static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2433 Operation *op,
2434 RankedTensorType expectedType) {
2435 switch (result) {
2436 case SliceVerificationResult::Success:
2437 return success();
2438 case SliceVerificationResult::RankTooLarge:
2439 return op->emitError("expected rank to be smaller or equal to ")
2440 << "the other rank. ";
2441 case SliceVerificationResult::SizeMismatch:
2442 return op->emitError("expected type to be ")
2443 << expectedType << " or a rank-reduced version. (size mismatch) ";
2444 case SliceVerificationResult::ElemTypeMismatch:
2445 return op->emitError("expected element type to be ")
2446 << expectedType.getElementType();
2447 default:
2448 llvm_unreachable("unexpected extract_slice op verification result");
2449 }
2450}
2451
2452/// Verifier for ExtractSliceOp.
2453LogicalResult ExtractSliceOp::verify() {
2454 RankedTensorType sourceType = getSourceType();
2455
2456 // Verify result type against inferred type.
2457 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2458 sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2459 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2460 if (result != SliceVerificationResult::Success)
2461 return produceSliceErrorMsg(result, *this, expectedType);
2462
2463 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2464 // to the source tensor.
2465 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2466 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2467 getStaticStrides(), /*generateErrorMessage=*/true);
2468 if (!boundsResult.isValid)
2469 return getOperation()->emitError(boundsResult.errorMessage);
2470
2471 return success();
2472}
2473
2474llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2475 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2476}
2477
2478FailureOr<Value>
2479ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2480 ArrayRef<int64_t> desiredShape) {
2481 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2482 assert(sourceTensorType && "not a ranked tensor type");
2483 auto sourceShape = sourceTensorType.getShape();
2484 if (sourceShape.equals(desiredShape))
2485 return value;
2486 auto maybeRankReductionMask =
2487 mlir::computeRankReductionMask(sourceShape, desiredShape);
2488 if (!maybeRankReductionMask)
2489 return failure();
2490 return createCanonicalRankReducingExtractSliceOp(
2491 b, loc, value,
2492 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2493}
2494
2495LogicalResult ExtractSliceOp::reifyResultShapes(
2496 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2497 reifiedReturnShapes.resize(1);
2498 reifiedReturnShapes[0].reserve(getType().getRank());
2499 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2500 llvm::SmallBitVector droppedDims = getDroppedDims();
2501 for (const auto &size : enumerate(mixedSizes)) {
2502 if (droppedDims.test(size.index()))
2503 continue;
2504 reifiedReturnShapes[0].push_back(size.value());
2505 }
2506 return success();
2507}
2508
2509namespace {
2510/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2511/// This essentially pushes the tensor.cast past its consuming slice when
2512/// `canFoldIntoConsumerOp` is true.
2513///
2514/// Example:
2515/// ```
2516/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2517/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2518/// tensor<3x4xf32>
2519/// ```
2520/// is rewritten into:
2521/// ```
2522/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2523/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2524/// ```
2525class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2526public:
2527 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2528
2529 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2530 PatternRewriter &rewriter) const override {
2531 // Any constant operand, just return to let the constant folder kick in.
2532 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2533 return matchPattern(operand, matchConstantIndex());
2534 }))
2535 return failure();
2536
2537 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2538 if (!castOp)
2539 return failure();
2540
2541 if (!canFoldIntoConsumerOp(castOp))
2542 return failure();
2543
2544 // Pattern does not apply if the produced op would not verify.
2545 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2546 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2547 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2548 sliceOp.getStaticStrides());
2549 if (!sliceResult.isValid)
2550 return failure();
2551
2552 // Create folded extract.
2553 Location loc = sliceOp.getLoc();
2554 Value newResult = ExtractSliceOp::create(
2555 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2556 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2557 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2558 sliceOp.getStaticStrides());
2559 rewriter.replaceOp(sliceOp, newResult);
2560 return success();
2561 }
2562};
2563
2564/// Slice elements from `values` into `outValues`. For each dimension,
2565/// `counts` holds the number of elements spanned by a unit step along that
2566/// dimension. The output values can be used to construct a DenseElementsAttr.
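///
/// For illustration (hypothetical data): for a 2x3 row-major source holding
/// the values 0..5, with counts = [3, 1], offsets = [0, 1], sizes = [2, 2]
/// and strides = [1, 1], the collected elements are 1, 2, 4, 5.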
2567template <typename IterTy, typename ElemTy>
2568static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2569 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2570 ArrayRef<int64_t> strides,
2571 llvm::SmallVectorImpl<ElemTy> *outValues) {
2572 assert(offsets.size() == sizes.size());
2573 assert(offsets.size() == strides.size());
2574 if (offsets.empty())
2575 return;
2576
2577 int64_t offset = offsets.front();
2578 int64_t size = sizes.front();
2579 int64_t stride = strides.front();
2580 if (offsets.size() == 1) {
2581 for (int64_t i = 0; i < size; ++i, offset += stride)
2582 outValues->push_back(*(values + offset));
2583
2584 return;
2585 }
2586
2587 for (int64_t i = 0; i < size; ++i, offset += stride) {
2588 auto begin = values + offset * counts.front();
2589 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2590 offsets.drop_front(), sizes.drop_front(),
2591 strides.drop_front(), outValues);
2592 }
2593}
2594
2595/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2596/// folded operation might introduce more constant data; users can restrict
2597/// when the fold applies via the control function.
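///
/// For illustration (hypothetical IR):
///
///   %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
///   %0 = tensor.extract_slice %cst[0, 1] [2, 1] [1, 1]
///       : tensor<2x2xi32> to tensor<2x1xi32>
///
/// folds to:
///
///   %0 = arith.constant dense<[[1], [3]]> : tensor<2x1xi32>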
2598class ConstantOpExtractSliceFolder final
2599 : public OpRewritePattern<ExtractSliceOp> {
2600public:
2601 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2602
2603 ConstantOpExtractSliceFolder(MLIRContext *context,
2604 ControlConstantExtractSliceFusionFn controlFn)
2605 : OpRewritePattern<ExtractSliceOp>(context),
2606 controlFn(std::move(controlFn)) {}
2607
2608 LogicalResult matchAndRewrite(ExtractSliceOp op,
2609 PatternRewriter &rewriter) const override {
2610 DenseElementsAttr attr;
2611 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2612 return failure();
2613
2614 // A constant splat is handled by fold().
2615 if (attr.isSplat())
2616 return failure();
2617
2618 // Dynamic result shape is not supported.
2619 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2620 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2621 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2622 return failure();
2623
2624 // Customized control over the folding.
2625 if (!controlFn(op))
2626 return failure();
2627
2628 int64_t count = sourceType.getNumElements();
2629 if (count == 0)
2630 return failure();
2631
2632 // Check if there are any dynamic parts, which are not supported.
2633 auto offsets = op.getStaticOffsets();
2634 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2635 return failure();
2636 auto sizes = op.getStaticSizes();
2637 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2638 return failure();
2639 auto strides = op.getStaticStrides();
2640 if (llvm::is_contained(strides, ShapedType::kDynamic))
2641 return failure();
2642
2643 // Compute the stride for each dimension.
2644 SmallVector<int64_t> counts;
2645 ArrayRef<int64_t> shape = sourceType.getShape();
2646 counts.reserve(shape.size());
2647 for (int64_t v : shape) {
2648 count = count / v;
2649 counts.push_back(count);
2650 }
2651
2652 // New attribute constructed by the sliced values.
2653 DenseElementsAttr newAttr;
2654
2655 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2656 SmallVector<APInt> outValues;
2657 outValues.reserve(sourceType.getNumElements());
2658 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2659 elems.begin(), counts, offsets, sizes, strides, &outValues);
2660 newAttr = DenseElementsAttr::get(resultType, outValues);
2661 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2662 SmallVector<APFloat> outValues;
2663 outValues.reserve(sourceType.getNumElements());
2664 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2665 elems.begin(), counts, offsets, sizes, strides, &outValues);
2666 newAttr = DenseElementsAttr::get(resultType, outValues);
2667 }
2668
2669 if (newAttr) {
2670 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2671 return success();
2672 }
2673
2674 return failure();
2675 }
2676
2677private:
2678 /// Controls whether the fold is applied; users can encode their heuristics
2679 /// in this function.
2680 ControlConstantExtractSliceFusionFn controlFn;
2681};
2682
2683} // namespace
2684
2685void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2686 RewritePatternSet &patterns,
2687 const ControlConstantExtractSliceFusionFn &controlFn) {
2688 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2689}
2690
2691/// Return the canonical type of the result of an extract_slice op.
2692struct SliceReturnTypeCanonicalizer {
2693 RankedTensorType operator()(ExtractSliceOp op,
2694 ArrayRef<OpFoldResult> mixedOffsets,
2695 ArrayRef<OpFoldResult> mixedSizes,
2696 ArrayRef<OpFoldResult> mixedStrides) {
2697 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2698 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2699 mixedStrides);
2700 }
2701};
2702
2703/// A canonicalizer wrapper to replace ExtractSliceOps.
2704struct SliceCanonicalizer {
2705 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2706 ExtractSliceOp newOp) {
2707 Value replacement = newOp.getResult();
2708 if (replacement.getType() != op.getType())
2709 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2710 replacement);
2711 rewriter.replaceOp(op, replacement);
2712 }
2713};
2714
2715void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2716 MLIRContext *context) {
2717 results.add<
2718 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2719 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2720 ExtractSliceOpCastFolder>(context);
2721}
2722
2723//
2724static LogicalResult
2725foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2726 ShapedType shapedType) {
2727 OpBuilder b(op.getContext());
2728 for (OpFoldResult ofr : op.getMixedOffsets())
2729 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2730 return failure();
2731 // Rank-reducing noops only need to inspect the leading dimensions:
2732 // llvm::zip is appropriate.
2733 auto shape = shapedType.getShape();
2734 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2735 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2736 return failure();
2737 for (OpFoldResult ofr : op.getMixedStrides())
2738 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2739 return failure();
2740 return success();
2741}
2742
2743/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2744/// slice, we can return the InsertSliceOp's source directly.
2745// TODO: This only checks the immediate producer; extend to go up the
2746// insert/extract chain if the slices are disjoint.
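///
/// For illustration (hypothetical IR):
///
///   %0 = tensor.insert_slice %src into %dest[0, 0] [2, 4] [1, 1]
///       : tensor<2x4xf32> into tensor<8x8xf32>
///   %1 = tensor.extract_slice %0[0, 0] [2, 4] [1, 1]
///       : tensor<8x8xf32> to tensor<2x4xf32>
///
/// Here %1 folds to %src.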
2747static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2748 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2749
2750 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2751 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2752 insertOp.isSameAs(extractOp, isSame))
2753 return insertOp.getSource();
2754
2755 return {};
2756}
2757
2758OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2759 if (OpFoldResult reshapedSource = reshapeConstantSource(
2760 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2761 getResult().getType()))
2762 return reshapedSource;
2763 if (getSourceType() == getType() &&
2764 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2765 return this->getSource();
2766 if (Value slice = foldExtractAfterInsertSlice(*this))
2767 return slice;
2768
2769 return OpFoldResult();
2770}
2771
2772Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2773 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2774 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2775 unsigned rank = rankedTensorType.getRank();
2776 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2777 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2778 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2779 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2780 offsets, sizes, strides);
2781}
2782
2783//===----------------------------------------------------------------------===//
2784// InsertSliceOp
2785//===----------------------------------------------------------------------===//
2786
2787void InsertSliceOp::getAsmResultNames(
2788 function_ref<void(Value, StringRef)> setNameFn) {
2789 setNameFn(getResult(), "inserted_slice");
2790}
2791
2792// Build a InsertSliceOp with mixed static and dynamic entries.
2793void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2794 Value dest, ArrayRef<OpFoldResult> offsets,
2795 ArrayRef<OpFoldResult> sizes,
2796 ArrayRef<OpFoldResult> strides,
2797 ArrayRef<NamedAttribute> attrs) {
2798 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2799 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2800 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2801 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2802 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2803 result.addAttributes(attrs);
2804 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2805 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2806 b.getDenseI64ArrayAttr(staticSizes),
2807 b.getDenseI64ArrayAttr(staticStrides));
2808}
2809
2810/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2811/// Range vector.
2812void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2813 Value dest, ArrayRef<Range> ranges,
2814 ArrayRef<NamedAttribute> attrs) {
2815 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2816 build(b, result, source, dest, offsets, sizes, strides, attrs);
2817}
2818
2819// Build a InsertSliceOp with dynamic entries.
2820void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2821 Value dest, ValueRange offsets, ValueRange sizes,
2822 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2823 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2824 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2825 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2826 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2827 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2828 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2829 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2830}
2831
2832/// Rank-reducing type verification for both InsertSliceOp and
2833/// ParallelInsertSliceOp.
2834static SliceVerificationResult verifyInsertSliceOp(
2835 RankedTensorType srcType, RankedTensorType dstType,
2836 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2837 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2838 // insert_slice is the inverse of extract_slice, use the same type
2839 // inference.
2840 RankedTensorType expected = ExtractSliceOp::inferResultType(
2841 dstType, staticOffsets, staticSizes, staticStrides);
2842 if (expectedType)
2843 *expectedType = expected;
2844 return isRankReducedType(expected, srcType);
2845}
2846
2847/// Verifier for InsertSliceOp.
2848LogicalResult InsertSliceOp::verify() {
2849 // Verify result type against inferred type.
2850 RankedTensorType expectedType;
2851 SliceVerificationResult result =
2852 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2853 getStaticSizes(), getStaticStrides(), &expectedType);
2854 if (result != SliceVerificationResult::Success)
2855 return produceSliceErrorMsg(result, *this, expectedType);
2856
2857 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2858 // to the destination tensor.
2859 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2860 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2861 getStaticStrides(), /*generateErrorMessage=*/true);
2862 if (!boundsResult.isValid)
2863 return getOperation()->emitError(boundsResult.errorMessage);
2864
2865 return success();
2866}
2867
2868/// If we have two consecutive InsertSliceOp writing to the same slice, we
2869/// can mutate the second InsertSliceOp's destination to the first one's.
2870///
2871/// Example:
2872///
2873/// ```mlir
2874/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2875/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2876/// ```
2877///
2878/// folds into:
2879///
2880/// ```mlir
2881/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2882/// ```
2883///
2884/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2885static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2886 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2887
2888 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2889 if (!prevInsertOp ||
2890 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2891 !prevInsertOp.isSameAs(insertOp, isSame))
2892 return failure();
2893
2894 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2895 return success();
2896}
2897
2898/// Folds round-trip extract/insert slice op pairs.
2899/// Example:
2900/// ```mlir
2901/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2902/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2903/// ```
2904/// can be folded into %val.
2905static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2906 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2907
2908 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2909 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2910 !extractOp.isSameAs(insertOp, isSame))
2911 return nullptr;
2912
2913 return extractOp.getSource();
2914}
2915
2916OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2917 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2918 getSourceType() == getType() &&
2919 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2920 return this->getSource();
2921 if (succeeded(foldInsertAfterInsertSlice(*this)))
2922 return getResult();
2923 if (auto result = foldInsertAfterExtractSlice(*this))
2924 return result;
2925 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2926 return getDest();
2927 return OpFoldResult();
2928}
2929
2930LogicalResult InsertSliceOp::reifyResultShapes(
2931 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2932 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2933 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2934 return success();
2935}
2936
2937namespace {
2938/// Pattern to rewrite an insert_slice op with constant arguments.
2939///
2940/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
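///
/// For illustration (hypothetical IR; %c0 and %c4 are assumed to be
/// arith.constant index values 0 and 4):
///
///   %0 = tensor.insert_slice %src into %dest[%c0, %c0] [%c4, %c4] [1, 1]
///       : tensor<?x?xf32> into tensor<8x8xf32>
///
/// is rewritten to an insert_slice with static offsets [0, 0] and sizes
/// [4, 4], with %src first cast to tensor<4x4xf32>.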
2941template <typename InsertOpTy>
2942class InsertSliceOpConstantArgumentFolder final
2943 : public OpRewritePattern<InsertOpTy> {
2944public:
2945 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2946
2947 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2948 PatternRewriter &rewriter) const override {
2949 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2950 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2951 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2952
2953 // No constant operands were folded, just return.
2954 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2955 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2956 failed(foldDynamicStrideList(mixedStrides)))
2957 return failure();
2958
2959 // Pattern does not apply if the produced op would not verify.
2960 SliceBoundsVerificationResult sliceResult =
2961 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2962 mixedOffsets, mixedSizes, mixedStrides);
2963 if (!sliceResult.isValid)
2964 return failure();
2965
2966 // Create the new op in canonical form.
2967 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2968 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2969 mixedOffsets, mixedSizes, mixedStrides);
2970 Value toInsert = insertSliceOp.getSource();
2971 if (sourceType != insertSliceOp.getSourceType()) {
2972 OpBuilder::InsertionGuard g(rewriter);
2973 // The only difference between InsertSliceOp and ParallelInsertSliceOp
2974 // is that the insertion point is just before the InParallelOp in
2975 // the parallel case.
2976 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
2977 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2978 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
2979 sourceType, toInsert);
2980 }
2981 rewriter.replaceOpWithNewOp<InsertOpTy>(
2982 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2983 mixedSizes, mixedStrides);
2984 return success();
2985 }
2986};
2987
2988/// Fold tensor_casts with insert_slice operations. If the source or
2989/// destination tensor is a tensor_cast that removes static type information,
2990/// the cast is folded into the insert_slice operation. E.g.:
2991///
2992/// ```mlir
2993/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2994/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2995/// ```
2996///
2997/// folds into:
2998///
2999/// ```mlir
3000/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3001/// ```
3002///
3003/// Note: When folding a cast on the destination tensor, the result of the
3004/// insert_slice operation is cast back to ensure that the type of the result
3005/// does not change.
3006///
3007/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3008template <typename InsertOpTy>
3009struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3010 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3011
3012 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3013 PatternRewriter &rewriter) const override {
3014 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3015 return matchPattern(operand, matchConstantIndex());
3016 }))
3017 return failure();
3018
3019 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3020 auto castOp = v.getDefiningOp<tensor::CastOp>();
3021 if (!castOp || !canFoldIntoConsumerOp(castOp))
3022 return std::nullopt;
3023 return castOp.getSource();
3024 };
3025 std::optional<Value> sourceCastSource =
3026 getSourceOfCastOp(insertSliceOp.getSource());
3027 std::optional<Value> destCastSource =
3028 getSourceOfCastOp(insertSliceOp.getDest());
3029 if (!sourceCastSource && !destCastSource)
3030 return failure();
3031
3032 auto src =
3033 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3034 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3035 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3036 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3037 if (!srcType || !dstType)
3038 return failure();
3039
3040 // The tensor.cast source could have additional static information not seen
3041 // in the insert slice op static sizes, so we ignore dynamic dims when
3042 // computing the rank reduction mask.
3043 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3044 auto rankReductionMask = computeRankReductionMask(
3045 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3046 if (!rankReductionMask.has_value())
3047 return failure();
3048 // Replace dimensions in the insert slice op with corresponding static dims
3049 // from the cast source type. If the insert slice sizes have static dims
3050 // that are not static in the tensor.cast source (i.e., when the cast op
3051 // casts a dynamic dim to static), the dim should not be replaced, and the
3052 // pattern will fail later in `verifyInsertSliceOp`.
3053 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3054 int64_t rankReducedIdx = 0;
3055 for (auto [idx, size] : enumerate(staticSizes)) {
3056 if (!rankReductionMask.value().contains(idx) &&
3057 !srcType.isDynamicDim(rankReducedIdx)) {
3058 mixedSizes[idx] = getAsIndexOpFoldResult(
3059 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3060 size = srcType.getDimSize(rankReducedIdx++);
3061 }
3062 }
3063
3064 // Pattern does not apply if the produced op would not verify.
3065 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3066 staticSizes, insertSliceOp.getStaticStrides()) !=
3067 SliceVerificationResult::Success)
3068 return failure();
3069 SliceBoundsVerificationResult sliceResult =
3070 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3071 mixedSizes, insertSliceOp.getMixedStrides());
3072 if (!sliceResult.isValid)
3073 return failure();
3074
3075 Operation *replacement =
3076 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3077 insertSliceOp.getMixedOffsets(), mixedSizes,
3078 insertSliceOp.getMixedStrides());
3079
3080 // In the parallel case there is no result and so nothing to cast.
3081 bool isParallelInsert =
3082 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3083 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3084 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3085 insertSliceOp.getDestType(),
3086 replacement->getResult(0));
3087 }
3088 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3089 return success();
3090 }
3091};
3092
3093/// If additional static type information can be deduced from an insert_slice's
3094/// size operands, insert an explicit cast of the op's source operand. This
3095/// enables other canonicalization patterns that match tensor_cast
3096/// ops such as `ForOpTensorCastFolder` in SCF.
3097///
3098/// Example:
3099///
3100/// ```mlir
3101/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3102/// : tensor<?x?xf32> into ...
3103/// ```
3104///
3105/// folds into:
3106///
3107/// ```mlir
3108/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3109/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3110/// : tensor<64x64xf32> into ...
3111/// ```
3112///
3113/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3114template <typename InsertOpTy>
3115struct InsertSliceOpSourceCastInserter final
3116 : public OpRewritePattern<InsertOpTy> {
3117 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3118
3119 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3120 PatternRewriter &rewriter) const override {
3121 RankedTensorType srcType = insertSliceOp.getSourceType();
3122 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3123 return failure();
3124 SmallVector<int64_t> newSrcShape(srcType.getShape());
3125 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3126 if (std::optional<int64_t> constInt =
3127 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3128 // Bail on invalid IR.
3129 if (*constInt < 0)
3130 return failure();
3131 newSrcShape[i] = *constInt;
3132 }
3133 }
3134 if (!hasValidSizesOffsets(newSrcShape))
3135 return failure();
3136
3137 RankedTensorType newSrcType = RankedTensorType::get(
3138 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3139 if (srcType == newSrcType ||
3140 !preservesStaticInformation(srcType, newSrcType) ||
3141 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3142 return failure();
3143
3144 // newSrcType is:
3145 // 1) Different from srcType.
3146 // 2) "More static" than srcType.
3147 // 3) Cast-compatible with srcType.
3148 // Insert the cast.
3149 OpBuilder::InsertionGuard g(rewriter);
3150 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3151 // that the insertion point is just before the InParallelOp in the
3152 // parallel case.
3153 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3154 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3155 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3156 newSrcType, insertSliceOp.getSource());
3157 rewriter.replaceOpWithNewOp<InsertOpTy>(
3158 insertSliceOp, cast, insertSliceOp.getDest(),
3159 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3160 insertSliceOp.getMixedStrides());
3161 return success();
3162 }
3163};
3164} // namespace
3165
3166llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3167 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3168}
3169
3170void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3171 MLIRContext *context) {
3172 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3173 InsertSliceOpCastFolder<InsertSliceOp>,
3174 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3175}
3176
3177Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3178                                                             Location loc,
3179 Value tensor,
3180 Value dest) {
3181 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3182 unsigned rank = rankedTensorType.getRank();
3183 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3184 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3185 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3186 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3187 sizes, strides);
3188}
3189
3190//===----------------------------------------------------------------------===//
3191// PadOp
3192//===----------------------------------------------------------------------===//
3193
3194void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3195 setNameFn(getResult(), "padded");
3196}
3197
3198LogicalResult PadOp::verify() {
3199 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3200 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3201 auto expectedType =
3202 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3203 if (!expectedType) {
3204 return emitError("failed to infer expectedType from sourceType ")
3205 << sourceType << ", specified resultType is " << resultType;
3206 }
3207 if (resultType.getRank() != expectedType.getRank()) {
3208 return emitError("specified type ")
3209 << resultType << " does not match the inferred type "
3210 << expectedType;
3211 }
3212 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3213 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3214 continue;
3215 if (expectedType.isDynamicDim(i))
3216 continue;
3217 return emitError("specified type ")
3218 << resultType << " does not match the inferred type "
3219 << expectedType;
3220 }
3221
3222 return success();
3223}
3224
3225LogicalResult PadOp::verifyRegions() {
3226 auto &region = getRegion();
3227 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3228 Block &block = region.front();
3229 if (block.getNumArguments() != rank)
3230 return emitError("expected the block to have ") << rank << " arguments";
3231
3232 // Note: the number and type of yield values are checked in the YieldOp.
3233 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3234 if (!en.value().isIndex())
3235 return emitOpError("expected block argument ")
3236 << (en.index() + 1) << " to be an index";
3237 }
3238
3239 // Ensure that the region yields an element of the right type.
3240 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3241 if (yieldOp.getValue().getType() !=
3242 llvm::cast<ShapedType>(getType()).getElementType())
3243 return emitOpError("expected yield type to match shape element type");
3244
3245 return success();
3246}
3247
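// Illustrative example of the shape inference below (hypothetical values):
// with sourceType = tensor<4x?xf32>, staticLow = [1, 1], and
// staticHigh = [2, 0], the inferred result type is tensor<7x?xf32>
// (dim 0: 4 + 1 + 2 = 7; dim 1 remains dynamic).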
3248RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3249 ArrayRef<int64_t> staticLow,
3250 ArrayRef<int64_t> staticHigh,
3251 ArrayRef<int64_t> resultShape) {
3252 unsigned rank = sourceType.getRank();
3253 if (staticLow.size() != rank)
3254 return RankedTensorType();
3255 if (staticHigh.size() != rank)
3256 return RankedTensorType();
3257 if (!resultShape.empty() && resultShape.size() != rank)
3258 return RankedTensorType();
3259
3260 SmallVector<int64_t, 4> inferredShape;
3261 for (auto i : llvm::seq<unsigned>(0, rank)) {
3262 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3263 staticHigh[i] == ShapedType::kDynamic) {
3264 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3265 : resultShape[i]);
3266 } else {
3267 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3268 assert((resultShape.empty() || size == resultShape[i] ||
3269 resultShape[i] == ShapedType::kDynamic) &&
3270 "mismatch between inferred shape and result shape");
3271 inferredShape.push_back(size);
3272 }
3273 }
3274
3275 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3276}
3277
3278void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3279 Value source, ArrayRef<int64_t> staticLow,
3280 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3281 bool nofold, ArrayRef<NamedAttribute> attrs) {
3282 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3283 if (!resultType)
3284 resultType = inferResultType(sourceType, staticLow, staticHigh);
3285 result.addAttributes(attrs);
3286 build(b, result, resultType, source, low, high,
3287 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3288 nofold ? b.getUnitAttr() : UnitAttr());
3289}
3290
3291void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3292 Value source, ValueRange low, ValueRange high, bool nofold,
3293 ArrayRef<NamedAttribute> attrs) {
3294 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3295 unsigned rank = sourceType.getRank();
3296 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3297 build(b, result, resultType, source, staticVector, staticVector, low, high,
3298 nofold, attrs);
3299}
3300
3301void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3302 Value source, ArrayRef<OpFoldResult> low,
3303 ArrayRef<OpFoldResult> high, bool nofold,
3304 ArrayRef<NamedAttribute> attrs) {
3305 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3306 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3307 SmallVector<int64_t, 4> staticLow, staticHigh;
3308 // staticLow and staticHigh have full information of the padding config.
3309 // This will grow staticLow and staticHigh with 1 value. If the config is
3310 // dynamic (i.e., not a constant), dynamicLow and dynamicHigh will grow with 1
3311 // value as well.
3312 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3313 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3314 if (!resultType) {
3315 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3316 }
3317 assert(llvm::isa<RankedTensorType>(resultType));
3318 result.addAttributes(attrs);
3319 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3320 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3321 nofold ? b.getUnitAttr() : UnitAttr());
3322}
3323
3324void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3325 Value source, ArrayRef<OpFoldResult> low,
3326 ArrayRef<OpFoldResult> high, Value constantPadValue,
3327 bool nofold, ArrayRef<NamedAttribute> attrs) {
3328 build(b, result, resultType, source, low, high, nofold, attrs);
3329
3330 // Add a region and a block to yield the pad value.
3331 Region *region = result.regions[0].get();
3332 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3333 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3334 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3335
3336 // `builder.createBlock` changes the insertion point within the block. Create
3337 // a guard to reset the insertion point of the builder after it is destroyed.
3338 OpBuilder::InsertionGuard guard(b);
3339 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3340 tensor::YieldOp::create(b, result.location, constantPadValue);
3341}
3342
3343llvm::SmallBitVector PadOp::getPaddedDims() {
3344 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3345 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3346 for (const auto &en : enumerate(paddingWidths))
3347 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3348 paddedDims.set(en.index());
3349 };
3350 extractPaddedDims(getMixedLowPad());
3351 extractPaddedDims(getMixedHighPad());
3352 return paddedDims;
3353}
3354
3355namespace {
3356// Folds tensor.pad when padding is static zeros and the attribute
3357// doesn't request otherwise.
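//
// For instance (hypothetical values):
//   %p = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x8xf32> to tensor<4x8xf32>
// is replaced with a tensor.cast of %t to the result type.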
3358struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3359 using OpRewritePattern<PadOp>::OpRewritePattern;
3360
3361 LogicalResult matchAndRewrite(PadOp padTensorOp,
3362 PatternRewriter &rewriter) const override {
3363 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3364 return failure();
3365 if (padTensorOp.getNofold())
3366 return failure();
3367 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3368 padTensorOp, padTensorOp.getResult().getType(),
3369 padTensorOp.getSource());
3370 return success();
3371 }
3372};
3373
3374// Fold CastOp into PadOp when adding static information.
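//
// For instance (hypothetical values):
//   %c = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
//   %p = tensor.pad %c low[0, 0] high[2, 2] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
// becomes a pad of %0 with inferred type tensor<10x18xf32>, followed by a
// tensor.cast back to the original tensor<?x?xf32> result type.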
3375struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3376 using OpRewritePattern<PadOp>::OpRewritePattern;
3377
3378 LogicalResult matchAndRewrite(PadOp padTensorOp,
3379 PatternRewriter &rewriter) const override {
3380 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3381 if (!tensor::canFoldIntoConsumerOp(castOp))
3382 return failure();
3383
3384 auto newResultType = PadOp::inferResultType(
3385 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3386 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3387 padTensorOp.getResultType().getShape());
3388
3389 if (newResultType == padTensorOp.getResultType()) {
3390 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3391 padTensorOp.getSourceMutable().assign(castOp.getSource());
3392 });
3393 } else {
3394 auto newOp = PadOp::create(
3395 rewriter, padTensorOp->getLoc(), newResultType,
3396 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3397 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3398 padTensorOp.getHigh(), padTensorOp.getNofold(),
3399 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3400 IRMapping mapper;
3401 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3402
3403 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3404 padTensorOp, padTensorOp.getResultType(), newOp);
3405 }
3406 return success();
3407 }
3408};
3409
3410// Fold CastOp using the result of PadOp back into the latter if it adds
3411// static information.
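//
// For instance (hypothetical values):
//   %p = tensor.pad %0 low[0] high[%h] { ... } : tensor<8xf32> to tensor<?xf32>
//   %c = tensor.cast %p : tensor<?xf32> to tensor<16xf32>
// becomes a single tensor.pad that directly produces tensor<16xf32>.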
3412struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3413 using OpRewritePattern<PadOp>::OpRewritePattern;
3414
3415 LogicalResult matchAndRewrite(PadOp padTensorOp,
3416 PatternRewriter &rewriter) const override {
3417 if (!padTensorOp.getResult().hasOneUse())
3418 return failure();
3419 auto tensorCastOp =
3420 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3421 if (!tensorCastOp)
3422 return failure();
3423 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3424 tensorCastOp.getDest().getType()))
3425 return failure();
3426
3427 auto replacementOp = PadOp::create(
3428 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3429 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3430 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3431 padTensorOp.getHigh(), padTensorOp.getNofold(),
3432 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3433 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3434
3435 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3436 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3437 return success();
3438 }
3439};
3440
3441/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3442/// different dimensions. The pattern applies if the following preconditions
3443/// hold:
3444/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3445/// 2) the tensor::ExtractSliceOps have only unit-strides,
3446/// 3) the tensor::PadOps perform only high-padding,
3447/// 4) the tensor::PadOps have the same constant padding value,
3448/// 5) the tensor::PadOps do not have common padding dimensions,
3449/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3450/// zero-offset for every dimension.
3451/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3452///    the padded source dimensions.
3454///
3455/// Example:
3456///
3457/// ```mlir
3458/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3459/// : tensor<64x64xf32> to tensor<?x64xf32>
3460/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3461/// } : tensor<?x64xf32> to tensor<8x64xf32>
3462/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3463/// : tensor<8x64xf32> to tensor<8x?xf32>
3464/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3465/// } : tensor<8x?xf32> to tensor<8x4xf32>
3466/// ```
3467///
3468/// folds into:
3469///
3470/// ```mlir
3471/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3472/// : tensor<64x64xf32> to tensor<?x?xf32>
3473/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3474/// } : tensor<?x?xf32> to tensor<8x4xf32>
3475/// ```
3476struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3477 using OpRewritePattern<PadOp>::OpRewritePattern;
3478
3479 LogicalResult matchAndRewrite(PadOp padOp,
3480 PatternRewriter &rewriter) const override {
3481 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3482 if (!innerSliceOp)
3483 return failure();
3484 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3485 if (!outerPadOp || outerPadOp.getNofold())
3486 return failure();
3487 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3488 if (!outerSliceOp)
3489 return failure();
3490
3491 // 1) Fail if the chain is rank-reducing.
3492 int64_t rank = padOp.getSourceType().getRank();
3493 if (outerSliceOp.getSourceType().getRank() != rank) {
3494 return rewriter.notifyMatchFailure(padOp,
3495 "cannot fold rank-reducing chain");
3496 }
3497
3498 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3499 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3500 return rewriter.notifyMatchFailure(
3501 padOp, "cannot fold non-unit stride ExtractSliceOps");
3502 }
3503
3504 // 3) Fail if the tensor::PadOps have non-zero low padding.
3505 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3506 return rewriter.notifyMatchFailure(padOp,
3507 "cannot fold PadOps with low padding");
3508 }
3509
3510 // 4) Fail if the tensor::PadOps padding values do not match.
3511 Attribute innerAttr, outerAttr;
3512 Value innerValue = padOp.getConstantPaddingValue();
3513 Value outerValue = outerPadOp.getConstantPaddingValue();
3514 if (!innerValue || !outerValue ||
3515 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3516 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3517 innerAttr != outerAttr) {
3518 return rewriter.notifyMatchFailure(
3519 padOp, "cannot fold PadOps with different padding values");
3520 }
3521
3522 // 5) Fail if a dimension is padded by both tensor::PadOps.
3523 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3524 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3525 if (innerDims.anyCommon(outerDims)) {
3526 return rewriter.notifyMatchFailure(
3527 padOp, "cannot fold PadOps with common padding dimensions");
3528 }
3529
3530 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3531 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3532 // for every dimension, and use the offset of the other pair. Fail if no
3533 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3534 // exists.
3535 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3536 for (auto en : enumerate(newOffsets)) {
3537 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3538 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3539 if (!innerDims.test(en.index()) &&
3540 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3541 en.value() = outerOffset;
3542 continue;
3543 }
3544 if (!outerDims.test(en.index()) &&
3545 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3546 en.value() = innerOffset;
3547 continue;
3548 }
3549 return rewriter.notifyMatchFailure(
3550 padOp, "cannot find zero-offset and zero-padding pair");
3551 }
3552
3553 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3554 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3555 // outer tensor::PadOp and fail if the size of the inner
3556 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3557 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3558 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3559 for (auto en : enumerate(newSizes)) {
3560 if (!outerDims.test(en.index()))
3561 continue;
3562 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3563 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3564 assert(ShapedType::isStatic(sourceSize) &&
3565 "expected padded dimension to have a static size");
3566 if (getConstantIntValue(sliceSize) != sourceSize) {
3567 return rewriter.notifyMatchFailure(
3568 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3569 "match the size of the outer padding");
3570 }
3571 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3572 }
3573
3574 // Combine the high paddings of the two tensor::PadOps.
3575 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3576 for (auto en : enumerate(newHighPad)) {
3577 if (innerDims.test(en.index()))
3578 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3579 if (outerDims.test(en.index()))
3580 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3581 }
3582
3583 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3584 // the two paddings in one step.
3585 auto newSliceOp = ExtractSliceOp::create(
3586 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3587 newSizes, innerSliceOp.getMixedStrides());
3588 auto newPadOp = PadOp::create(
3589 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3590 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3591 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3592 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3593 newPadOp.getRegion().begin());
3594 rewriter.replaceOp(padOp, newPadOp.getResult());
3595 return success();
3596 }
3597};
3598
3599struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3600 using OpRewritePattern<PadOp>::OpRewritePattern;
3601
3602 LogicalResult matchAndRewrite(PadOp padTensorOp,
3603 PatternRewriter &rewriter) const override {
3604 Value input = padTensorOp.getSource();
3605 if (!llvm::isa<RankedTensorType>(input.getType()))
3606 return failure();
3607 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3608 auto inputRank = inputDims.size();
3609
3610 auto oldResultType =
3611 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3612 if (!oldResultType)
3613 return failure();
3614
3615 auto outputDims = oldResultType.getShape();
3616
3617 // Extract the static info from the high and low operands.
3618 SmallVector<int64_t> constOperandsLow;
3619 SmallVector<Value> newLows;
3620 for (auto operand : padTensorOp.getLow()) {
3621 APSInt intOp;
3622 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3623 constOperandsLow.push_back(ShapedType::kDynamic);
3624 newLows.push_back(operand);
3625 continue;
3626 }
3627 constOperandsLow.push_back(intOp.getExtValue());
3628 }
3629 SmallVector<int64_t> constOperandsHigh;
3630 SmallVector<Value> newHighs;
3631 for (auto operand : padTensorOp.getHigh()) {
3632 APSInt intOp;
3633 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3634 constOperandsHigh.push_back(ShapedType::kDynamic);
3635 newHighs.push_back(operand);
3636 continue;
3637 }
3638 constOperandsHigh.push_back(intOp.getExtValue());
3639 }
3640
3641 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3642 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3643
3644 // Verify the op is well-formed.
3645 if (inputDims.size() != outputDims.size() ||
3646 inputDims.size() != constLow.size() ||
3647 inputDims.size() != constHigh.size())
3648 return failure();
3649
3650 auto lowCount = 0;
3651 auto highCount = 0;
3652 for (size_t i = 0; i < inputRank; i++) {
3653 if (constLow[i] == ShapedType::kDynamic)
3654 constLow[i] = constOperandsLow[lowCount++];
3655 if (constHigh[i] == ShapedType::kDynamic)
3656 constHigh[i] = constOperandsHigh[highCount++];
3657 }
3658
3659 auto staticLow = ArrayRef<int64_t>(constLow);
3660 auto staticHigh = ArrayRef<int64_t>(constHigh);
3661
3662 // Calculate the output sizes with the static information.
3663 SmallVector<int64_t> newOutDims;
3664 for (size_t i = 0; i < inputRank; i++) {
3665 if (outputDims[i] == ShapedType::kDynamic) {
3666 newOutDims.push_back(
3667 (staticLow[i] == ShapedType::kDynamic ||
3668 staticHigh[i] == ShapedType::kDynamic ||
3669 inputDims[i] == ShapedType::kDynamic
3670 ? ShapedType::kDynamic
3671 : inputDims[i] + staticLow[i] + staticHigh[i]));
3672 } else {
3673 newOutDims.push_back(outputDims[i]);
3674 }
3675 }
3676
3677 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3678 llvm::all_of(newOutDims,
3679 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3680 return failure();
3681
3682 // Rewrite the op using the new static type.
3683 auto newResultType = RankedTensorType::get(
3684 newOutDims, padTensorOp.getType().getElementType());
3685 auto newOp = PadOp::create(
3686 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3687 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3688 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3689
3690 IRMapping mapper;
3691 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3692 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3693 newOp);
3694
3695 return success();
3696 }
3697};
3698
3699/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3700///
3701/// Example:
3702///
3703/// ```mlir
3704/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3705/// tensor.yield %val
3706/// } : tensor<1x2xf32> to tensor<1x5xf32>
3707/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3708/// tensor.yield %val
3709/// } : tensor<1x5xf32> to tensor<4x7xf32>
3710/// ```
3711///
3712/// folds into:
3713///
3714/// ```mlir
3715/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3716/// tensor.yield %val
3717/// } : tensor<1x2xf32> to tensor<4x7xf32>
3718/// ```
3719struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3720 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3721
3722 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3723 PatternRewriter &rewriter) const override {
3724 if (padOp.getNofold()) {
3725 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3726 }
3727
3728 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3729 if (!producerPad || producerPad.getNofold()) {
3730 return rewriter.notifyMatchFailure(
3731 padOp, "producer is not a foldable tensor.pad op");
3732 }
3733
3734 // Fail if the tensor::PadOps padding values do not match.
3735 Value consumerPadValue = padOp.getConstantPaddingValue();
3736 Value producerPadValue = producerPad.getConstantPaddingValue();
3737 if (!consumerPadValue || !producerPadValue ||
3738 consumerPadValue != producerPadValue) {
3739 return rewriter.notifyMatchFailure(
3740 padOp,
3741 "cannot fold PadOps with different or non-constant padding values");
3742 }
3743
3744 Location loc = padOp.getLoc();
3745 AffineExpr d0, d1;
3746 bindDims(rewriter.getContext(), d0, d1);
3747
3748 // Combine the low/high paddings of the two tensor::PadOps.
3749 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3750 ArrayRef<OpFoldResult> producerPaddings) {
3751 SmallVector<OpFoldResult> sumPaddings;
3752 for (auto [consumerIndex, producerIndex] :
3753 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3754 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3755 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3756 }
3757 return sumPaddings;
3758 };
3759
3760 SmallVector<OpFoldResult> newHighPad =
3761 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3762 SmallVector<OpFoldResult> newLowPad =
3763 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3764
3765 auto newPadOp = tensor::PadOp::create(
3766 rewriter, padOp.getLoc(), padOp.getResultType(),
3767 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3768 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3769 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3770 newPadOp.getRegion().begin());
3771 rewriter.replaceOp(padOp, newPadOp.getResult());
3772 return success();
3773 }
3774};
3775
3776} // namespace
3777
3778LogicalResult
3779PadOp::reifyResultShapes(OpBuilder &b,
3780 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3781 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3782 SmallVector<OpFoldResult> lp = getMixedLowPad();
3783 SmallVector<OpFoldResult> hp = getMixedHighPad();
3784 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3785 if (!getType().isDynamicDim(i)) {
3786 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3787 continue;
3788 }
3789 Location loc = getLoc();
3790 Value dim = b.createOrFold<tensor::DimOp>(
3791 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3792
3793 AffineExpr d0, d1, d2;
3794 bindDims(b.getContext(), d0, d1, d2);
3795 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3796 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3797 }
3798 return success();
3799}
3800
3801void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3802 MLIRContext *context) {
3803 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3804 FoldOrthogonalPaddings, FoldStaticPadding,
3805 FoldConsecutiveConstantPadding>(context);
3806}
3807
3808/// Return the padding value of the PadOp if it is constant. In this context,
3809/// "constant" means an actual constant or "defined outside of the block".
3810///
3811/// Values are considered constant in three cases:
3812/// - A ConstantLike value.
3813/// - A basic block argument from a different block.
3814/// - A value defined outside of the block.
3815///
3816/// If the padding value is not constant, an empty Value is returned.
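///
/// For instance (hypothetical IR):
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %t low[1] high[1] {
/// ^bb0(%i: index):
///   tensor.yield %cst : f32
/// } : tensor<?xf32> to tensor<?xf32>
/// ```
///
/// returns %cst; if the yielded value were computed inside the pad block
/// (e.g. from %i), an empty Value would be returned instead.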
3817Value PadOp::getConstantPaddingValue() {
3818 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3819 if (!yieldOp)
3820 return {};
3821 Value padValue = yieldOp.getValue();
3822 // Check if yield value is a constant.
3823 if (matchPattern(padValue, m_Constant()))
3824 return padValue;
3825 // Check if yield value is defined inside the PadOp block.
3826 if (padValue.getParentBlock() == &getRegion().front())
3827 return {};
3828 // Else: Yield value defined outside of the PadOp block.
3829 return padValue;
3830}
3831
3832OpFoldResult PadOp::fold(FoldAdaptor) {
3833 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3834 !getNofold())
3835 return getSource();
3836 return {};
3837}
3838
3839//===----------------------------------------------------------------------===//
3840// ParallelInsertSliceOp
3841//===----------------------------------------------------------------------===//
3842
3843OpResult ParallelInsertSliceOp::getTiedOpResult() {
3844 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3845 for (const auto &it :
3846 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3847 Operation &nextOp = it.value();
3848 if (&nextOp == getOperation())
3849 return parallelCombiningParent.getParentResult(it.index());
3850 }
3851 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3852}
3853
3854// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3855void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3856 Value source, Value dest,
3857 ArrayRef<OpFoldResult> offsets,
3858 ArrayRef<OpFoldResult> sizes,
3859 ArrayRef<OpFoldResult> strides,
3860 ArrayRef<NamedAttribute> attrs) {
3861 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3862 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3863 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3864 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3865 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3866 result.addAttributes(attrs);
3867 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3868 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3869 b.getDenseI64ArrayAttr(staticSizes),
3870 b.getDenseI64ArrayAttr(staticStrides));
3871}
3872
3873/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3874/// packed into a Range vector.
3875void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3876 Value source, Value dest,
3877 ArrayRef<Range> ranges,
3878 ArrayRef<NamedAttribute> attrs) {
3879 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3880 build(b, result, source, dest, offsets, sizes, strides, attrs);
3881}
3882
3883// Build a ParallelInsertSliceOp with dynamic entries.
3884void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3885 Value source, Value dest, ValueRange offsets,
3886 ValueRange sizes, ValueRange strides,
3887 ArrayRef<NamedAttribute> attrs) {
3888 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3889 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3890 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3891 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3892 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3893 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3894 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3895}
3896
3897LogicalResult ParallelInsertSliceOp::verify() {
3898 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3899 return this->emitError("expected InParallelOpInterface parent, got: ")
3900 << *(getOperation()->getParentOp());
3901
3902 // Verify result type against inferred type.
3903 RankedTensorType expectedType;
3904 SliceVerificationResult result =
3905 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3906 getStaticSizes(), getStaticStrides(), &expectedType);
3907 if (result != SliceVerificationResult::Success)
3908 return produceSliceErrorMsg(result, *this, expectedType);
3909
3910 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3911 // to the destination tensor.
3912 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3913 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3914 getStaticStrides(), /*generateErrorMessage=*/true);
3915 if (!boundsResult.isValid)
3916 return getOperation()->emitError(boundsResult.errorMessage);
3917
3918 return success();
3919}
3920
3921void ParallelInsertSliceOp::getCanonicalizationPatterns(
3922 RewritePatternSet &results, MLIRContext *context) {
3923 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3924 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3925 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3926}
3927
3928llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3929 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3930}
3931
3932// ParallelCombiningOpInterface implementation.
3933MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3934 return getDestMutable();
3935}
3936
3937Operation *ParallelInsertSliceOp::getIteratingParent() {
3938 // Return the parent InParallelOpInterface's parent.
3939 if (auto combiningOp =
3940 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3941 return combiningOp->getParentOp();
3942 return nullptr;
3943}
3944
3945//===----------------------------------------------------------------------===//
3946// ScatterOp
3947//===----------------------------------------------------------------------===//
3948
3949void ScatterOp::getAsmResultNames(
3950 function_ref<void(Value, StringRef)> setNameFn) {
3951 setNameFn(getResult(), "scatter");
3952}
3953
3954LogicalResult ScatterOp::verify() {
3955 int64_t destRank = getDestType().getRank();
3956 ArrayRef<int64_t> scatterDims = getScatterDims();
3957 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3958 getIndicesType().getShape(), destRank,
3959 "scatter", "dest")))
3960 return failure();
3961
3962 if (!getUnique())
3963 return emitOpError("requires 'unique' attribute to be set");
3964 // TODO: we could also check statically that there are fewer leading index
3965 // tensor dims than the dest dims. If this is not the case, the unique
3966 // attribute cannot be true.
3967
3968 // Use the GatherOp::inferResultType on the `dest` type and verify the
3969 // expected type matches the source type.
3970 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3971 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3972 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3973 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3974 if (getSourceType() != expectedSourceType &&
3975 getSourceType() != expectedRankReducedSourceType) {
3976 return emitOpError("source type "
3977 "mismatch: "
3978 "expected ")
3979 << expectedSourceType << " or its rank-reduced variant "
3980 << expectedRankReducedSourceType << " (got: " << getSourceType()
3981 << ")";
3982 }
3983
3984 return success();
3985}
3986
3987//===----------------------------------------------------------------------===//
3988// SplatOp
3989//===----------------------------------------------------------------------===//
3990
3991void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3992 Type aggregateType, ValueRange dynamicSizes) {
3993 build(builder, result, aggregateType, element, dynamicSizes);
3994}
3995
3996void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3997 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3998 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3999 build(builder, result, aggregateType, element, dynamicSizes);
4000}
4001
4002void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4003 ArrayRef<OpFoldResult> sizes) {
4004 SmallVector<int64_t> staticShape;
4005 SmallVector<Value> dynamicSizes;
4006 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4007 build(builder, result, element, staticShape, dynamicSizes);
4008}
4009
4010void SplatOp::getAsmResultNames(
4011 function_ref<void(Value, StringRef)> setNameFn) {
4012 setNameFn(getResult(), "splat");
4013}
4014
4015LogicalResult SplatOp::verify() {
4016 if (getType().getNumDynamicDims() != getDynamicSizes().size())
4017 return emitOpError("incorrect number of dynamic sizes, has ")
4018 << getDynamicSizes().size() << ", expected "
4019 << getType().getNumDynamicDims();
4020 return success();
4021}
4022
4023LogicalResult
4024SplatOp::reifyResultShapes(OpBuilder &builder,
4025 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4026 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4027 unsigned ctr = 0;
4028 for (int64_t i = 0; i < getType().getRank(); ++i) {
4029 if (getType().isDynamicDim(i)) {
4030 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4031 } else {
4032 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4033 }
4034 }
4035 return success();
4036}
4037
4038OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4039 auto constOperand = adaptor.getInput();
4040 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4041 return {};
4042
4043 // Do not fold if the splat is not statically shaped
4044 if (!getType().hasStaticShape())
4045 return {};
4046
4047 // SplatElementsAttr::get treats single value for second arg as being a
4048 // splat.
4049 return SplatElementsAttr::get(getType(), {constOperand});
4050}
4051
4052//===----------------------------------------------------------------------===//
4053// Common Canonicalizers and Folders.
4054//===----------------------------------------------------------------------===//
4055static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4056 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4057 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4058 // might need special handling of attached regions.
4059 if (isa<InsertSliceOp>(op.getOperation()) ||
4060 isa<LoopLikeOpInterface>(op.getOperation()))
4061 return false;
4062
4063 return hasFoldableTensorCastOperand(op);
4064}
4065
4066/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4067/// the `tensor.cast` has source that is more static than the consuming op.
4068///
4069/// Example:
4070/// ```mlir
4071/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4072/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4073/// ```
4074///
4075/// folds into:
4076///
4077/// ```mlir
4078/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4079/// ```
4080/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4081/// can add the pattern to their canonicalizers.
4083 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4085 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4086
4087 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4088 PatternRewriter &rewriter) const override {
4089
4090 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4091 // for that instead.
4092 if (!foldTensorCastPrecondition(op) ||
4093 isa<linalg::RelayoutOpInterface>(*op))
4094 return failure();
4095
4096 SmallVector<Type> newResultTypes(op->getResultTypes());
4097 SmallVector<Value> newOperands =
4098 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4099
4100 // Clone op
4101 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4102
4103 SmallVector<Value, 4> replacements;
4104 replacements.reserve(newOp->getNumResults());
4105 for (auto [oldResult, newResult] :
4106 llvm::zip(op->getResults(), newOp->getResults())) {
4107 if (newResult.getType() != oldResult.getType()) {
4108 replacements.push_back(tensor::CastOp::create(
4109 rewriter, op->getLoc(), oldResult.getType(), newResult));
4110 } else {
4111 replacements.push_back(newResult);
4112 }
4113 }
4114 rewriter.replaceOp(op, replacements);
4115
4116 return success();
4117 }
4118};
4119
4120//===----------------------------------------------------------------------===//
4121// TensorDialect
4122//===----------------------------------------------------------------------===//
4123
4124void TensorDialect::getCanonicalizationPatterns(
4125 RewritePatternSet &results) const {
4126 results.add<FoldTensorCastProducerOp>(getContext());
4127}
4128
4129//===----------------------------------------------------------------------===//
4130// TableGen'd op method definitions
4131//===----------------------------------------------------------------------===//
4132
4133#define GET_OP_CLASSES
4134#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
unsigned getNumArguments()
Definition Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:378
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
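A minimal sketch of how a consumer-side fold might use this query; the helper name lookThroughFoldableCast is hypothetical. When the cast only relaxes static shape information, a consumer can use the more static source value directly.

#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Hypothetical helper: if `operand` comes from a tensor.cast that only makes
// the type more dynamic, read through the cast and use its source instead.
static mlir::Value lookThroughFoldableCast(mlir::Value operand) {
  if (auto castOp = operand.getDefiningOp<mlir::tensor::CastOp>())
    if (mlir::tensor::canFoldIntoConsumerOp(castOp))
      return castOp.getSource(); // Source carries more static information.
  return operand;
}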
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:57
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:75
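A hedged usage sketch; the wrapper name destinationFor is made up here. A transformation that needs somewhere to write a tensor result can ask for the tied init operand of a DPS op, or get a freshly created tensor.empty otherwise.

#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Hypothetical wrapper: obtain a destination value for `result`, failing if
// the result shape cannot be reified for a non-DPS producer.
static mlir::FailureOr<mlir::Value> destinationFor(mlir::OpBuilder &b,
                                                   mlir::OpResult result) {
  return mlir::tensor::getOrCreateDestination(b, result.getLoc(), result);
}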
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
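A minimal sketch built on this helper; the function makeEmptyLike is illustrative and not part of this file's API surface. It queries the mixed (static/dynamic) sizes of a ranked tensor and creates a tensor.empty of the same shape and element type.

#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Illustrative helper: allocate a tensor.empty with the same (possibly
// dynamic) shape as `tensor`, e.g. to serve as an init operand.
static mlir::Value makeEmptyLike(mlir::OpBuilder &b, mlir::Location loc,
                                 mlir::Value tensor) {
  auto type = llvm::cast<mlir::RankedTensorType>(tensor.getType());
  llvm::SmallVector<mlir::OpFoldResult> sizes =
      mlir::tensor::getMixedSizes(b, loc, tensor);
  return mlir::tensor::EmptyOp::create(b, loc, sizes, type.getElementType());
}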
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
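A short hedged example; isKnownUnitStride is a made-up name. This is the usual way to test whether a mixed stride or size is a specific compile-time constant.

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include <optional>

// Illustrative check: is this mixed stride statically known to be 1?
static bool isKnownUnitStride(mlir::OpFoldResult stride) {
  std::optional<int64_t> cst = mlir::getConstantIntValue(stride);
  return cst && *cst == 1;
}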
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
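A small worked example, assuming the integer overload declared in the dialect indexing utilities: with row-major strides {6, 2, 1}, linear index 11 decomposes as 11 = 1*6 + 2*2 + 1*1.

#include "mlir/Dialect/Utils/IndexingUtils.h"

// Illustrative call: expected result is {1, 2, 1} for strides {6, 2, 1}.
static llvm::SmallVector<int64_t> exampleDelinearize() {
  return mlir::delinearize(/*linearIndex=*/11, /*strides=*/{6, 2, 1});
}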
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
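A hedged sketch; the header path is assumed from where these arith utilities usually live. It materializes a mixed value as an SSA index Value, creating a constant op only when the OpFoldResult holds an attribute.

#include "mlir/Dialect/Arith/Utils/Utils.h"

// Illustrative wrapper: turn an OpFoldResult into an SSA value of index type.
static mlir::Value asIndexValue(mlir::OpBuilder &b, mlir::Location loc,
                                mlir::OpFoldResult ofr) {
  return mlir::getValueOrCreateConstantIndexOp(b, loc, ofr);
}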
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:90
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
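A small worked example, under the assumption that this is the overload declared alongside the builtin type utilities: rank-reducing 1x4x1x8 to 4x8 drops the unit dimensions at positions 0 and 2.

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/DenseSet.h"
#include <optional>

// Illustrative call: expected mask is {0, 2} (the dropped unit dimensions).
static std::optional<llvm::SmallDenseSet<unsigned>> exampleRankReductionMask() {
  return mlir::computeRankReductionMask(/*originalShape=*/{1, 4, 1, 8},
                                        /*reducedShape=*/{4, 8});
}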
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
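A minimal sketch, following the signature listed above: split a list of mixed sizes into the static shape (with dynamic markers) and the dynamic SSA values, the form most ops with paired static/dynamic operands expect.

#include "mlir/Dialect/Utils/StaticValueUtils.h"

// Illustrative split of mixed sizes into static and dynamic components.
static void splitMixedSizes(llvm::ArrayRef<mlir::OpFoldResult> mixedSizes) {
  auto [staticSizes, dynamicSizes] = mlir::decomposeMixedValues(mixedSizes);
  (void)staticSizes;  // SmallVector<int64_t>, with dynamic-size placeholders.
  (void)dynamicSizes; // SmallVector<Value>, one per dynamic entry.
}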
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.