TensorOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19#include "mlir/IR/Builders.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/Matchers.h"
33#include "mlir/Support/LLVM.h"
34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SmallBitVector.h"
37#include "llvm/ADT/StringRef.h"
38#include "llvm/Support/Casting.h"
39#include "llvm/Support/MathExtras.h"
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::tensor;
44
45/// Materialize a single constant operation from a given attribute value with
46/// the desired resultant type.
47Operation *TensorDialect::materializeConstant(OpBuilder &builder,
48 Attribute value, Type type,
49 Location loc) {
50 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
51 return op;
52 if (complex::ConstantOp::isBuildableWith(value, type))
53 return complex::ConstantOp::create(builder, loc, type,
54 llvm::cast<ArrayAttr>(value));
55 return nullptr;
56}
57
58OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
59 int64_t dim) {
60 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
61 if (tensorType.isDynamicDim(dim))
62 return builder.createOrFold<tensor::DimOp>(loc, value, dim);
63
64 return builder.getIndexAttr(tensorType.getDimSize(dim));
65}
66
67SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
68 Location loc, Value value) {
69 auto tensorType = llvm::cast<RankedTensorType>(value.getType());
70 SmallVector<OpFoldResult> result;
71 for (int64_t i = 0; i < tensorType.getRank(); ++i)
72 result.push_back(getMixedSize(builder, loc, value, i));
73 return result;
74}
75
76FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
77 OpResult opResult) {
78 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
79 assert(tensorType && "expected tensor type");
80
81 // If the op has a destination, it implements DestinationStyleOpInterface and
82 // we can query the destination operand from that interface.
83 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
84 if (destOp)
85 return destOp.getTiedOpOperand(opResult)->get();
86
87 // Otherwise, create a new destination tensor with the same shape.
88 OpBuilder::InsertionGuard g(b);
89 b.setInsertionPoint(opResult.getDefiningOp());
90
91 // Compute sizes.
92 SmallVector<OpFoldResult> mixedSizes;
93 if (!tensorType.hasStaticShape()) {
94 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
95 ReifiedRankedShapedTypeDims reifiedShapes;
96 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
97 return failure();
98 mixedSizes = reifiedShapes[opResult.getResultNumber()];
99 } else {
100 // Static shape: Take static sizes directly.
101 for (int64_t sz : tensorType.getShape())
102 mixedSizes.push_back(b.getIndexAttr(sz));
103 }
104
105 // Create empty tensor.
106 Value emptyTensor =
107 tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
108 return emptyTensor;
109}
110
111LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
112 Operation *op,
113 SmallVector<Value> &result) {
114 for (OpResult opResult : op->getResults()) {
115 if (llvm::isa<TensorType>(opResult.getType())) {
116 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
117 if (failed(destination))
118 return failure();
119 result.push_back(*destination);
120 }
121 }
122 return success();
123}
124
125bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
126 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
127 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
128 return rtp1.getShape() == rtp2.getShape() &&
129 rtp1.getElementType() == rtp2.getElementType();
130 return false;
131 }
132 return tp1 == tp2; // default implementation
133}
134
135/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
136/// rank-extending tensor.insert_slice op.
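/// A hypothetical illustration (shapes are not from a specific test): for a
/// rank-reducing extract_slice with mixedSizes [1, %sz, 1, 6] and reducedShape
/// [%sz, 6], the two static unit sizes do not appear in the reduced shape, so
/// the returned bit vector marks positions 0 and 2 as dropped.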
137static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
138 ArrayRef<OpFoldResult> mixedSizes) {
139 llvm::SmallBitVector droppedDims(mixedSizes.size());
140 int64_t shapePos = reducedShape.size() - 1;
141
142 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
143 size_t idx = mixedSizes.size() - size.index() - 1;
144 // Rank-reduced dims must have a static unit dimension.
145 bool isStaticUnitSize =
146 isa<Attribute>(size.value()) &&
147 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
148
149 if (shapePos < 0) {
150 // There are no more dims in the reduced shape. All remaining sizes must
151 // be rank-reduced dims.
152 assert(isStaticUnitSize && "expected unit dim");
153 droppedDims.set(idx);
154 continue;
155 }
156
157 // Dim is preserved if the size is not a static 1.
158 if (!isStaticUnitSize) {
159 --shapePos;
160 continue;
161 }
162
163 // Dim is preserved if the reduced shape dim is also 1.
164 if (reducedShape[shapePos] == 1) {
165 --shapePos;
166 continue;
167 }
168
169 // Otherwise: Dim is dropped.
170 droppedDims.set(idx);
171 }
172
173 assert(shapePos < 0 && "dimension mismatch");
174 return droppedDims;
175}
176
177/// Given a ranked tensor type and a range of values that defines its dynamic
178/// dimension sizes, turn all dynamic sizes that have a constant value into
179/// static dimension sizes.
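/// For example (illustrative values): with type tensor<?x?xf32> and
/// dynamicSizes [%c5, %d], where %c5 is a constant 5, the returned type is
/// tensor<5x?xf32> and foldedDynamicSizes contains only %d.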
180static RankedTensorType
181foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
182 SmallVector<Value> &foldedDynamicSizes) {
183 SmallVector<int64_t> staticShape(type.getShape());
184 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
185 "incorrect number of dynamic sizes");
186
187 // Compute new static and dynamic sizes.
188 unsigned ctr = 0;
189 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
190 if (type.isDynamicDim(i)) {
191 Value dynamicSize = dynamicSizes[ctr++];
192 std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
193 if (cst.has_value()) {
194 // Dynamic size must be non-negative.
195 if (cst.value() < 0) {
196 foldedDynamicSizes.push_back(dynamicSize);
197 continue;
198 }
199 staticShape[i] = *cst;
200 } else {
201 foldedDynamicSizes.push_back(dynamicSize);
202 }
203 }
204 }
205
206 return RankedTensorType::get(staticShape, type.getElementType(),
207 type.getEncoding());
208}
209
210//===----------------------------------------------------------------------===//
211// BitcastOp
212//===----------------------------------------------------------------------===//
213
214bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
215 if (inputs.size() != 1 || outputs.size() != 1)
216 return false;
217 Type a = inputs.front(), b = outputs.front();
218 auto aT = dyn_cast<TensorType>(a);
219 auto bT = dyn_cast<TensorType>(b);
220 if (!aT || !bT)
221 return false;
222
223 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
224 return false;
225
226 return succeeded(verifyCompatibleShape(aT, bT));
227}
228
229namespace {
230
231/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
232/// operation.
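/// A sketch of the rewrite (types are illustrative):
///
/// ```mlir
/// %0 = tensor.bitcast %arg0 : tensor<4xi32> to tensor<4xui32>
/// %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %1 = tensor.bitcast %arg0 : tensor<4xi32> to tensor<4xf32>
/// ```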
233struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
234 using OpRewritePattern<BitcastOp>::OpRewritePattern;
235
236 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
237 PatternRewriter &rewriter) const final {
238 auto tensorBitcastOperand =
239 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
240 if (!tensorBitcastOperand)
241 return failure();
242
243 auto resultType = cast<TensorType>(tensorBitcast.getType());
244 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
245 tensorBitcastOperand.getOperand());
246 return success();
247 }
248};
249
250} // namespace
251
252void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
253 MLIRContext *context) {
254 results.add<ChainedTensorBitcast>(context);
255}
256
257//===----------------------------------------------------------------------===//
258// CastOp
259//===----------------------------------------------------------------------===//
260
261void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
262 setNameFn(getResult(), "cast");
263}
264
265/// Returns true if `target` is a ranked tensor type that preserves static
266/// information available in the `source` ranked tensor type.
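/// For example (illustrative types): preservesStaticInformation(
/// tensor<8x?xf32>, tensor<8x16xf32>) is true, while swapping the arguments
/// returns false because the target would lose the static size 16.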
267bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
268 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
269 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
270
271 // Requires RankedTensorType.
272 if (!sourceType || !targetType)
273 return false;
274
275 // Requires same elemental type.
276 if (sourceType.getElementType() != targetType.getElementType())
277 return false;
278
279 // Requires same rank.
280 if (sourceType.getRank() != targetType.getRank())
281 return false;
282
283 // Requires same encoding.
284 if (sourceType.getEncoding() != targetType.getEncoding())
285 return false;
286
287 // If cast is towards more static sizes along any dimension, don't fold.
288 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
289 if (ShapedType::isStatic(std::get<0>(t)) &&
290 ShapedType::isDynamic(std::get<1>(t)))
291 return false;
292 }
293
294 return true;
295}
296
297/// Determines whether tensor::CastOp casts to a more dynamic version of the
298/// source tensor. This is useful to fold a tensor.cast into a consuming op and
299/// implement canonicalization patterns for ops in different dialects that may
300/// consume the results of tensor.cast operations. Such foldable tensor.cast
301/// operations are typically inserted as `slice` ops and are canonicalized,
302/// to preserve the type compatibility of their uses.
303///
304/// Returns true when all conditions are met:
305/// 1. source and result are ranked tensors with same element type and rank.
306/// 2. the source type has more static information than the result type.
307///
308/// Example:
309/// ```mlir
310/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
311/// %2 = consumer %1 ... : tensor<?x?xf32> ...
312/// ```
313///
314/// folds into:
315///
316/// ```mlir
317/// %2 = consumer %0 ... : tensor<8x16xf32> ...
318/// ```
319bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
320 if (!castOp)
321 return false;
322
323 // Can fold if the source of cast has at least as much static information as
324 // its results.
325 return preservesStaticInformation(castOp.getType(),
326 castOp.getSource().getType());
327}
328
329/// Determines whether the tensor::CastOp casts to a more static version of the
330/// source tensor. This is useful to fold into a producing op and implement
331/// canonicalization patterns with the `tensor.cast` op as the root, but
332/// producer being from different dialects. Returns true when all conditions are
333/// met:
334/// 1. source and result are ranked tensors with same element type and rank.
335/// 2. the result type has more static information than the source.
336///
337/// Example:
338/// ```mlir
339/// %1 = producer ... : tensor<?x?xf32>
340/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
341/// ```
342///
343/// can be canonicalized to :
344///
345/// ```mlir
346/// %2 = producer ... : tensor<8x16xf32>
347/// ```
348/// Not all ops can be canonicalized this way, but for those that can be,
349/// this method provides a check that it is worth doing the canonicalization.
350bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
351 if (!castOp)
352 return false;
353 return preservesStaticInformation(castOp.getSource().getType(),
354 castOp.getType());
355}
356
357bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
358 return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
359 if (llvm::isa<BlockArgument>(opOperand.get()))
360 return false;
361 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
362 return castOp && canFoldIntoConsumerOp(castOp);
363 });
364}
365
366SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
367 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
368 SmallVector<Value> newOperands;
369 newOperands.reserve(op->getNumOperands());
370
371 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
372
373 // Assumes that the result has dpsInits followed by nonDpsInits.
374 int64_t dpsInitIdx = 0;
375 for (OpOperand &opOperand : op->getOpOperands()) {
376 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
377 bool fold = canFoldIntoConsumerOp(tensorCastOp);
378 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
379 if (op.isDpsInit(&opOperand) &&
380 !llvm::isa<MemRefType>(newOperands.back().getType()))
381 newResTy[dpsInitIdx++] = newOperands.back().getType();
382 }
383 return newOperands;
384}
385
386/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
387/// that can be folded.
388LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
389 bool folded = false;
390 for (OpOperand &operand : op->getOpOperands()) {
391 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
392 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
393 operand.set(castOp.getOperand());
394 folded = true;
395 }
396 }
397 return success(folded);
398}
399
400bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
401 if (inputs.size() != 1 || outputs.size() != 1)
402 return false;
403 Type a = inputs.front(), b = outputs.front();
404 auto aT = llvm::dyn_cast<TensorType>(a);
405 auto bT = llvm::dyn_cast<TensorType>(b);
406 if (!aT || !bT)
407 return false;
408
409 if (aT.getElementType() != bT.getElementType())
410 return false;
411
412 return succeeded(verifyCompatibleShape(aT, bT));
413}
414
415/// Compute a TensorType that has the joined shape knowledge of the two
416/// given TensorTypes. The element types need to match.
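/// For example (illustrative types):
///   joinShapes(tensor<?x16xf32>, tensor<8x?xf32>) == tensor<8x16xf32>
///   joinShapes(tensor<8x16xf32>, tensor<8x32xf32>) == {} (shapes conflict)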
417static TensorType joinShapes(TensorType one, TensorType two) {
418 assert(one.getElementType() == two.getElementType());
419
420 if (!one.hasRank())
421 return two;
422 if (!two.hasRank())
423 return one;
424
425 int64_t rank = one.getRank();
426 if (rank != two.getRank())
427 return {};
428
429 SmallVector<int64_t, 4> join;
430 join.reserve(rank);
431 for (int64_t i = 0; i < rank; ++i) {
432 if (one.isDynamicDim(i)) {
433 join.push_back(two.getDimSize(i));
434 continue;
435 }
436 if (two.isDynamicDim(i)) {
437 join.push_back(one.getDimSize(i));
438 continue;
439 }
440 if (one.getDimSize(i) != two.getDimSize(i))
441 return {};
442 join.push_back(one.getDimSize(i));
443 }
444 return RankedTensorType::get(join, one.getElementType());
445}
446
447namespace {
448
449/// Replaces chains of two tensor.cast operations by a single tensor.cast
450/// operation if doing so does not remove runtime constraints.
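/// For instance (illustrative types), casting tensor<?x?xf32> to
/// tensor<4x?xf32> and then back to tensor<?x?xf32> cannot be collapsed into a
/// single cast, because dropping the intermediate type would discard the
/// runtime constraint that the first dimension is 4; in that case the pattern
/// bails out.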
451struct ChainedTensorCast : public OpRewritePattern<CastOp> {
452 using OpRewritePattern<CastOp>::OpRewritePattern;
453
454 LogicalResult matchAndRewrite(CastOp tensorCast,
455 PatternRewriter &rewriter) const final {
456 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
457
458 if (!tensorCastOperand)
459 return failure();
460
461 auto sourceType =
462 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
463 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
464 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
465
466 // We can remove the intermediate cast if joining all three produces the
467 // same result as just joining the source and result shapes.
468 auto firstJoin =
469 joinShapes(joinShapes(sourceType, intermediateType), resultType);
470
471 // The join might not exist if the cast sequence would fail at runtime.
472 if (!firstJoin)
473 return failure();
474
475 // The newJoin always exists if the above join exists; it might just contain
476 // less information. If so, we cannot drop the intermediate cast, as doing
477 // so would remove runtime checks.
478 auto newJoin = joinShapes(sourceType, resultType);
479 if (firstJoin != newJoin)
480 return failure();
481
482 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
483 tensorCastOperand.getOperand());
484 return success();
485 }
486};
487
488/// Fold tensor.cast into tensor.extract_slice producer.
489/// Example:
490/// ```
491/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
492/// tensor<128x512xf32> to tensor<?x512xf32>
493/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
494/// ```
495/// ->
496/// ```
497/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
498/// tensor<128x512xf32> to tensor<16x512xf32>
499/// ```
500struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
501 using OpRewritePattern<CastOp>::OpRewritePattern;
502
503 LogicalResult matchAndRewrite(CastOp tensorCast,
504 PatternRewriter &rewriter) const final {
505 auto extractOperand =
506 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
507
508 // Cannot fold cast to unranked tensor.
509 auto rankedResultType =
510 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
511 if (!rankedResultType)
512 return failure();
513
514 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
515 rankedResultType.getShape() ==
516 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
517 .getShape())
518 return failure();
519
520 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
521 auto dimMask = computeRankReductionMask(
522 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
523 size_t dimIndex = 0;
524 for (size_t i = 0, e = sizes.size(); i < e; i++) {
525 if (dimMask && dimMask->count(i))
526 continue;
527 int64_t dim = rankedResultType.getShape()[dimIndex++];
528 if (ShapedType::isDynamic(dim))
529 continue;
530 sizes[i] = rewriter.getIndexAttr(dim);
531 }
532
533 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
534 tensorCast, rankedResultType, extractOperand.getSource(),
535 extractOperand.getMixedOffsets(), sizes,
536 extractOperand.getMixedStrides());
537 return success();
538 }
539};
540
541} // namespace
542
543void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
544 MLIRContext *context) {
545 results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
546}
547
548//===----------------------------------------------------------------------===//
549// ConcatOp
550//===----------------------------------------------------------------------===//
551
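/// Sketch of the inference (shapes are illustrative): concatenating
/// tensor<3x?xf32> and tensor<?x5xf32> along dim 0 infers tensor<?x5xf32>:
/// the extents along the concatenated dimension add up (3 + ? = ?), and the
/// remaining dimensions combine whatever static sizes are known (? and 5
/// give 5).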
552RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
553 assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
554 auto tensorTypes =
555 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
556 int64_t concatRank = tensorTypes[0].getRank();
557
558 // The concatenation dim must be in the range [0, rank).
559 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
560
561 SmallVector<int64_t> sizes(concatRank);
562 for (int64_t i = 0, e = concatRank; i < e; ++i) {
563 if (i == dim)
564 continue;
565 SaturatedInteger size;
566 for (auto tensorType : tensorTypes)
567 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
568 sizes[i] = size.asInteger();
569 }
570 auto concatSize = SaturatedInteger::wrap(0);
571 for (auto tensorType : tensorTypes)
572 concatSize =
573 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
574 sizes[dim] = concatSize.asInteger();
575 return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
576}
577
578void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
579 ValueRange inputs) {
580 FailureOr<RankedTensorType> resultType =
581 inferResultType(dim, inputs.getTypes());
582 assert(succeeded(resultType) && "failed to infer concatenation result type");
583 build(builder, result, *resultType, dim, inputs);
584}
585
586LogicalResult ConcatOp::verify() {
587 if (getInputs().size() < 1)
588 return emitOpError("requires at least one input");
589
590 SmallVector<RankedTensorType> inputTypes;
591 for (auto input : getInputs())
592 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
593
594 RankedTensorType resultType = getResultType();
595 int64_t resultRank = getRank();
596 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
597 return type.getRank() != resultRank;
598 }))
599 return emitOpError("rank of concatenated inputs must match result rank");
600
601 Type resultElementType = resultType.getElementType();
602 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
603 return type.getElementType() != resultElementType;
604 }))
605 return emitOpError("inputs and result element type must match");
606
607 int64_t dim = getDim();
608 if (dim >= resultRank)
609 return emitOpError("concatenation dim must be less than the tensor rank");
610
611 SmallVector<int64_t> sizes(resultRank);
612 for (int64_t i = 0, e = resultRank; i < e; ++i) {
613 if (i == dim)
614 continue;
615 SaturatedInteger size;
616 for (auto tensorType : inputTypes) {
617 FailureOr<SaturatedInteger> maybeSize =
618 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
619 if (failed(maybeSize))
620 return emitOpError("static concatenation size mismatch along ")
621 << "non-concatenated dimension " << i;
622 size = *maybeSize;
623 }
624 sizes[i] = size.asInteger();
625 }
626 auto concatSize = SaturatedInteger::wrap(0);
627 for (auto tensorType : inputTypes)
628 concatSize =
629 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
630 sizes[dim] = concatSize.asInteger();
631 auto inferredResultType =
632 RankedTensorType::get(sizes, inputTypes[0].getElementType());
633
634 for (auto [inferredSize, actualSize] :
635 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
636 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
637 ShapedType::isDynamic(actualSize);
638 if (!hasDynamic && inferredSize != actualSize)
639 return emitOpError("result type ")
640 << resultType << " does not match inferred shape "
641 << inferredResultType << " static sizes";
642 }
643
644 return success();
645}
646
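/// Decomposition sketch (illustrative): a concat of two tensors along dim 0
/// becomes a tensor.empty of the combined shape followed by one
/// tensor.insert_slice per input, each offset by the running size along the
/// concatenated dimension, with a final tensor.cast if the computed type
/// differs from the declared result type.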
647FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
648 size_t numInputs = getInputs().size();
649 uint64_t concatDim = getDim();
650
651 SmallVector<SmallVector<OpFoldResult>> inputShapes;
652 inputShapes.reserve(numInputs);
653 SmallVector<OpFoldResult> concatOffsets;
654 concatOffsets.reserve(numInputs);
655 SmallVector<OpFoldResult> outputShape;
656
657 AffineExpr addExpr =
658 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
659 OpFoldResult zero = builder.getIndexAttr(0);
660 Location loc = getLoc();
661 for (auto [index, input] : llvm::enumerate(getInputs())) {
662 SmallVector<OpFoldResult> inputShape =
663 tensor::getMixedSizes(builder, input.getLoc(), input);
664 if (index == 0) {
665 outputShape = inputShape;
666 concatOffsets.push_back(zero);
667 } else {
668 concatOffsets.push_back(outputShape[concatDim]);
669 outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
670 builder, loc, addExpr,
671 {outputShape[concatDim], inputShape[concatDim]});
672 }
673 inputShapes.emplace_back(std::move(inputShape));
674 }
675
676 Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
677 getType().getElementType());
678
679 int64_t rank = getType().getRank();
680 OpFoldResult one = builder.getIndexAttr(1);
681 SmallVector<OpFoldResult> strides(rank, one);
682 SmallVector<OpFoldResult> offsets(rank, zero);
683 for (auto [index, input] : llvm::enumerate(getInputs())) {
684 offsets[concatDim] = concatOffsets[index];
685 auto insertSlice = tensor::InsertSliceOp::create(
686 builder, loc, input, replacement, offsets, inputShapes[index], strides);
687 replacement = insertSlice.getResult();
688 }
689 if (replacement.getType() != getType()) {
690 replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
691 }
692 return SmallVector<Value>{replacement};
693}
694
695LogicalResult
696ConcatOp::reifyResultShapes(OpBuilder &builder,
697 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
698 ValueRange inputs = getInputs();
699 int64_t dim = getDim();
700 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
701
702 Value init = inputs[0];
703 int64_t rank = getType().getRank();
704
705 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
706
707 // Pre-populate the result sizes with as much static information as possible
708 // from the given result type, as well as the inferred result type; otherwise,
709 // use the dim sizes from the first input.
710 for (int64_t i = 0; i < rank; ++i) {
711 if (i == dim)
712 continue;
713 if (!getType().isDynamicDim(i)) {
714 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
715 } else if (!inferredResultType.isDynamicDim(i)) {
716 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
717 builder, getLoc(),
718 builder.getIndexAttr(inferredResultType.getDimSize(i)));
719 } else {
720 reifiedReturnShapes[0][i] =
721 tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
722 }
723 }
724
725 if (getType().isDynamicDim(dim)) {
726 // Take the sum of the input sizes along the concatenated dim.
727 AffineExpr sum = builder.getAffineDimExpr(0);
728 SmallVector<OpFoldResult> sizes = {
729 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
730 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
731 sum = sum + builder.getAffineDimExpr(idx + 1);
732 sizes.push_back(
733 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
734 }
735 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
736 builder, getLoc(),
737 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
738 } else {
739 // If the result shape is static along the concatenated dim, use the static
740 // shape.
741 reifiedReturnShapes[0][dim] =
742 builder.getIndexAttr(getType().getDimSize(dim));
743 }
744 return success();
745}
746
747void ConcatOp::getAsmResultNames(
748 function_ref<void(Value, StringRef)> setNameFn) {
749 setNameFn(getResult(), "concat");
750}
751
752OpFoldResult ConcatOp::fold(FoldAdaptor) {
753 ValueRange inputs = getInputs();
754 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
755 return inputs[0];
756 return {};
757}
758
759namespace {
760/// Fold a concat op with a single input to a cast.
761struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
762 using OpRewritePattern<ConcatOp>::OpRewritePattern;
763
764 LogicalResult matchAndRewrite(ConcatOp concatOp,
765 PatternRewriter &rewriter) const override {
766 if (concatOp.getInputs().size() != 1)
767 return failure();
768 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
769 concatOp.getInputs()[0]);
770 return success();
771 }
772};
773
774/// Propagate static shapes into the operands of a `tensor.concat`.
775///
776/// `tensor.concat` requires every operand to match on all dimensions except the
777/// concatenation dimension. If one operand is already static in those
778/// dimensions, the other operands may safely be refined to that same static
779/// shape.
780///
781/// Example:
782///
783/// ```mlir
784/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
785/// tensor<?x12xi32>
786/// ```
787/// ->
788/// ```mlir
789/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
790/// %2 = tensor.concat dim(0) %0, %cast :
791/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
792/// ```
793struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
794 using OpRewritePattern<ConcatOp>::OpRewritePattern;
795
796 LogicalResult matchAndRewrite(ConcatOp concatOp,
797 PatternRewriter &rewriter) const override {
798 int64_t dim = concatOp.getDim();
799 RankedTensorType inferredResultType =
800 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
801
802 // Find operands for which a more static shape can be inferred.
803 LogicalResult matched = failure();
804 // Inferred operand shapes are identical in every dimension except the
805 // concatenation dimension.
806 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
807 for (auto [operandIdx, operandType] :
808 llvm::enumerate(concatOp->getOperandTypes())) {
809 // Compute inferred type for operand.
810 inferredOperandShape[dim] =
811 cast<RankedTensorType>(operandType).getDimSize(dim);
812 auto inferredOperandType = RankedTensorType::get(
813 inferredOperandShape, inferredResultType.getElementType());
814
815 // Check if inferred type is more static.
816 if (!preservesStaticInformation(inferredOperandType, operandType)) {
817 matched = success();
818
819 // Use refined operand type and create cast from original operand.
820 auto castOp =
821 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
822 concatOp.getOperand(operandIdx));
823 rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
824 concatOp->setOperand(operandIdx, castOp->getResult(0));
825 });
826 }
827 }
828
829 return matched;
830 }
831};
832
833/// Ensure `tensor.concat`'s result type is at least as static as can be
834/// inferred from its operand types.
835///
836/// Example:
837/// ```mlir
838/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
839/// tensor<?x?xi32>
840/// ```
841/// ->
842/// ```mlir
843/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
844///     -> tensor<?x12xi32>
845/// %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
846/// ```
847struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
848 using OpRewritePattern<ConcatOp>::OpRewritePattern;
849
850 LogicalResult matchAndRewrite(ConcatOp concatOp,
851 PatternRewriter &rewriter) const override {
852 int64_t dim = concatOp.getDim();
853 RankedTensorType inferredResultType =
854 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
855
856 // The result type should be at least as static as inferred result type.
857 if (preservesStaticInformation(inferredResultType,
858 concatOp.getResultType())) {
859 return failure();
860 }
861
862 auto newConcatOp =
863 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
864 concatOp->getOperands());
865 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
866 newConcatOp);
867
868 return success();
869 }
870};
871} // namespace
872
873void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
874 MLIRContext *context) {
875 results
876 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
877 context);
878}
879
880//===----------------------------------------------------------------------===//
881// DimOp
882//===----------------------------------------------------------------------===//
883
884void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
885 setNameFn(getResult(), "dim");
886}
887
888void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
889 int64_t index) {
890 auto loc = result.location;
891 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
892 build(builder, result, source, indexValue);
893}
894
895std::optional<int64_t> DimOp::getConstantIndex() {
896 return getConstantIntValue(getIndex());
897}
898
899Speculation::Speculatability DimOp::getSpeculatability() {
900 auto constantIndex = getConstantIndex();
901 if (!constantIndex)
902 return Speculation::NotSpeculatable;
903
904 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
905 if (!rankedSourceType)
906 return Speculation::NotSpeculatable;
907
908 if (rankedSourceType.getRank() <= constantIndex)
909 return Speculation::NotSpeculatable;
910
911 return Speculation::Speculatable;
912}
913
914void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
915 SetIntLatticeFn setResultRange) {
916 setResultRange(getResult(),
917 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
918}
919
920OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
921 // All forms of folding require a known index.
922 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
923 if (!index)
924 return {};
925
926 // Folding for unranked types (UnrankedTensorType) is not supported.
927 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
928 if (!tensorType)
929 return {};
930
931 // Out of bound indices produce undefined behavior but are still valid IR.
932 // Don't choke on them.
933 int64_t indexVal = index.getInt();
934 if (indexVal < 0 || indexVal >= tensorType.getRank())
935 return {};
936
937 // Fold if the shape extent along the given index is known.
938 if (!tensorType.isDynamicDim(index.getInt())) {
939 Builder builder(getContext());
940 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
941 }
942
943 Operation *definingOp = getSource().getDefiningOp();
944
945 // Fold dim to the operand of tensor.generate.
946 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
947 auto resultType =
948 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
949 // The case where the type encodes the size of the dimension is handled
950 // above.
951 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
952
953 // Find the operand of the fromElements that corresponds to this index.
954 auto dynExtents = fromElements.getDynamicExtents().begin();
955 for (auto dim : resultType.getShape().take_front(index.getInt()))
956 if (ShapedType::isDynamic(dim))
957 dynExtents++;
958
959 return Value{*dynExtents};
960 }
961
962 // The size at the given index is now known to be a dynamic size.
963 unsigned unsignedIndex = index.getValue().getZExtValue();
964
965 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
966 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
967 // `resolve-shaped-type-result-dims` pass.
968 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
969 sliceOp.isDynamicSize(unsignedIndex)) {
970 return {sliceOp.getDynamicSize(unsignedIndex)};
971 }
972 }
973
974 // dim(cast) -> dim
975 if (succeeded(foldTensorCast(*this)))
976 return getResult();
977
978 return {};
979}
980
981namespace {
982/// Fold dim of a cast into the dim of the source of the tensor cast.
983struct DimOfCastOp : public OpRewritePattern<DimOp> {
984 using OpRewritePattern<DimOp>::OpRewritePattern;
985
986 LogicalResult matchAndRewrite(DimOp dimOp,
987 PatternRewriter &rewriter) const override {
988 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
989 if (!castOp)
990 return failure();
991 Value newSource = castOp.getOperand();
992 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
993 return success();
994 }
995};
996
997/// Fold dim of a destination passing style op into the dim of the corresponding
998/// init.
999struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1000 using OpRewritePattern<DimOp>::OpRewritePattern;
1001
1002 LogicalResult matchAndRewrite(DimOp dimOp,
1003 PatternRewriter &rewriter) const override {
1004 auto source = dimOp.getSource();
1005 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1006 if (!destOp)
1007 return failure();
1008
1009 auto resultIndex = cast<OpResult>(source).getResultNumber();
1010 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1011
1012 rewriter.modifyOpInPlace(
1013 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1014 return success();
1015 }
1016};
1017
1018/// Fold the dim of a tensor.reshape operation into an extract from the
1019/// reshape's shape operand.
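/// A sketch of the rewrite (illustrative IR):
///
/// ```mlir
/// %r = tensor.reshape %t(%shape) : (tensor<?xf32>, tensor<2xindex>)
///          -> tensor<?x?xf32>
/// %d = tensor.dim %r, %c1 : tensor<?x?xf32>
/// ```
///
/// becomes an extract from the shape operand:
///
/// ```mlir
/// %d = tensor.extract %shape[%c1] : tensor<2xindex>
/// ```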
1020struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1021 using OpRewritePattern<DimOp>::OpRewritePattern;
1022
1023 LogicalResult matchAndRewrite(DimOp dim,
1024 PatternRewriter &rewriter) const override {
1025 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1026
1027 if (!reshape)
1028 return failure();
1029
1030 // Since tensors are immutable, we don't need to worry about where to place
1031 // the extract call.
1032 rewriter.setInsertionPointAfter(dim);
1033 Location loc = dim.getLoc();
1034 Value extract =
1035 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1036 if (extract.getType() != dim.getType())
1037 extract =
1038 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1039 rewriter.replaceOp(dim, extract);
1040 return success();
1041 }
1042};
1043} // namespace
1044
1045void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1046 MLIRContext *context) {
1047 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// EmptyOp
1052//===----------------------------------------------------------------------===//
1053
1054void EmptyOp::build(OpBuilder &builder, OperationState &result,
1055 ArrayRef<int64_t> staticShape, Type elementType,
1056 Attribute encoding) {
1057 assert(none_of(staticShape, ShapedType::isDynamic) &&
1058 "expected only static sizes");
1059 build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1060}
1061
1062void EmptyOp::build(OpBuilder &builder, OperationState &result,
1063 ArrayRef<int64_t> staticShape, Type elementType,
1064 ValueRange dynamicSizes, Attribute encoding) {
1065 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1066 build(builder, result, tensorType, dynamicSizes);
1067}
1068
1069void EmptyOp::build(OpBuilder &builder, OperationState &result,
1070 ArrayRef<OpFoldResult> sizes, Type elementType,
1071 Attribute encoding) {
1072 SmallVector<int64_t> staticShape;
1073 SmallVector<Value> dynamicSizes;
1074 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1075 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1076}
1077
1078LogicalResult EmptyOp::verify() {
1079 return verifyDynamicDimensionCount(getOperation(), getType(),
1080 getDynamicSizes());
1081}
1082
1083LogicalResult
1084EmptyOp::reifyResultShapes(OpBuilder &builder,
1085 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1086 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1087 unsigned ctr = 0;
1088 for (int64_t i = 0; i < getType().getRank(); ++i) {
1089 if (getType().isDynamicDim(i)) {
1090 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1091 } else {
1092 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1093 }
1094 }
1095 return success();
1096}
1097
1098Value EmptyOp::getDynamicSize(unsigned idx) {
1099 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1100 unsigned ctr = 0;
1101 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1102 if (getType().isDynamicDim(i))
1103 ++ctr;
1104 return getDynamicSizes()[ctr];
1105}
1106
1107SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1108 SmallVector<OpFoldResult> result;
1109 unsigned ctr = 0;
1110 OpBuilder b(getContext());
1111 for (int64_t i = 0; i < getType().getRank(); ++i) {
1112 if (getType().isDynamicDim(i)) {
1113 result.push_back(getDynamicSizes()[ctr++]);
1114 } else {
1115 result.push_back(b.getIndexAttr(getType().getShape()[i]));
1116 }
1117 }
1118 return result;
1119}
1120
1121namespace {
1122/// Change the type of the result of a `tensor.empty` by making the result
1123/// type statically sized along dimensions that in the original operation were
1124/// defined as dynamic, but the size was defined using a `constant` op. For
1125/// example
1126///
1127/// %c5 = arith.constant 5: index
1128/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1129///
1130/// to
1131///
1132/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1133struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1134 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1135
1136 LogicalResult matchAndRewrite(EmptyOp op,
1137 PatternRewriter &rewriter) const override {
1138 SmallVector<Value> foldedDynamicSizes;
1139 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1140 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1141
1142 // Stop here if no dynamic size was promoted to static.
1143 if (foldedTensorType == op.getType())
1144 return failure();
1145
1146 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1147 foldedDynamicSizes);
1148 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1149 return success();
1150 }
1151};
1152
1153struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1154 using OpRewritePattern<DimOp>::OpRewritePattern;
1155
1156 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1157 PatternRewriter &rewriter) const override {
1158 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1159 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1160 if (!emptyTensorOp || !maybeConstantIndex)
1161 return failure();
1162 auto emptyTensorType = emptyTensorOp.getType();
1163 if (*maybeConstantIndex < 0 ||
1164 *maybeConstantIndex >= emptyTensorType.getRank() ||
1165 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1166 return failure();
1167 rewriter.replaceOp(dimOp,
1168 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1169 return success();
1170 }
1171};
1172
1173/// Canonicalize
1174///
1175/// ```mlir
1176/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1177/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1178/// ```
1179///
1180/// into
1181///
1182/// ```mlir
1183/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1184/// ```
1185///
1186/// This assumes the input program is correct in terms of its shape. So it is
1187/// safe to assume that `%d0` is in fact 4.
1188struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1189 using OpRewritePattern<CastOp>::OpRewritePattern;
1190
1191 LogicalResult matchAndRewrite(CastOp castOp,
1192 PatternRewriter &rewriter) const override {
1193 if (!canFoldIntoProducerOp(castOp))
1194 return failure();
1195 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1196 if (!producer)
1197 return failure();
1198
1199 auto resultType =
1200 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1201 ArrayRef<int64_t> resultShape = resultType.getShape();
1202 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1203 SmallVector<OpFoldResult> newMixedSizes;
1204 newMixedSizes.reserve(currMixedSizes.size());
1205 assert(resultShape.size() == currMixedSizes.size() &&
1206 "mismatch in result shape and sizes of empty op");
1207 for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1208 int64_t newDim = std::get<0>(it);
1209 OpFoldResult currDim = std::get<1>(it);
1210 // Case 1: The empty tensor dim is static. Check that the tensor cast
1211 // result dim matches.
1212 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1213 if (ShapedType::isDynamic(newDim) ||
1214 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1215 // Something is off, the cast result shape cannot be more dynamic
1216 // than the empty tensor result shape (enforced by
1217 // `canFoldIntoProducer`). Abort for now.
1218 return rewriter.notifyMatchFailure(
1219 producer, "mismatch in static value of shape of empty tensor "
1220 "result and cast result");
1221 }
1222 newMixedSizes.push_back(attr);
1223 continue;
1224 }
1225
1226 // Case 2 : The tensor cast shape is static, but empty tensor result
1227 // shape is dynamic.
1228 if (ShapedType::isStatic(newDim)) {
1229 newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1230 continue;
1231 }
1232
1233 // Case 3 : The tensor cast shape is dynamic and empty tensor result
1234 // shape is dynamic. Use the dynamic value from the empty tensor op.
1235 newMixedSizes.push_back(currDim);
1236 }
1237
1238 // TODO: Do not drop tensor encoding.
1239 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1240 resultType.getElementType());
1241 return success();
1242 }
1243};
1244
1245} // namespace
1246
1247void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1248 MLIRContext *context) {
1249 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1250 ReplaceEmptyTensorStaticShapeDims>(context);
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// ExtractOp
1255//===----------------------------------------------------------------------===//
1256
1257namespace {
1258
1259/// Canonicalizes the pattern of the form
1260///
1261/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1262/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1263///
1264/// to
1265///
1266/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1267struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1268 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1269
1270 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1271 PatternRewriter &rewriter) const final {
1272 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1273 if (!tensorCast)
1274 return failure();
1275 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1276 return failure();
1277 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1278 extract, tensorCast.getSource(), extract.getIndices());
1279 return success();
1280 }
1281};
1282
1283/// Canonicalizes the pattern of the form
1284///
1285/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1286/// tensor<12xf64>
1287/// %extracted_element = tensor.extract %val[%c10] :
1288/// tensor<12xf64>
1289///
1290/// to
1291///
1292/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1293struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1294 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1295
1296 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1297 PatternRewriter &rewriter) const final {
1298 auto collapseOp =
1299 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1300 if (!collapseOp)
1301 return failure();
1302 if (!collapseOp.getSrcType().hasStaticShape())
1303 return failure();
1304
1305 auto sourceSizes = collapseOp.getSrcType().getShape();
1306
1307 SmallVector<Value> indices(extractOp.getIndices().begin(),
1308 extractOp.getIndices().end());
1309 SmallVector<Value> sourceIndices;
1310 for (auto [index, group] :
1311 llvm::zip(indices, collapseOp.getReassociationIndices())) {
1312 assert(!group.empty() && "association indices groups cannot be empty");
1313 auto groupSize = group.size();
1314
1315 if (groupSize == 1) {
1316 sourceIndices.push_back(index);
1317 continue;
1318 }
1319
1320 SmallVector<int64_t> basis =
1321 llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1322 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1323 rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1324 llvm::append_range(sourceIndices, delinearize.getResults());
1325 }
1326 if (collapseOp.getReassociationIndices().empty()) {
1327 auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1328 int64_t srcRank =
1329 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1330 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1331 rewriter, extractOp.getLoc(), zeroAffineMap,
1332 ArrayRef<OpFoldResult>{});
1333 for (int64_t i = 0; i < srcRank; i++) {
1334 sourceIndices.push_back(
1335 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1336 }
1337 }
1338
1339 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1340 extractOp, collapseOp.getSrc(), sourceIndices);
1341 return success();
1342 }
1343};
1344
1345} // namespace
1346
1347void ExtractOp::getAsmResultNames(
1348 function_ref<void(Value, StringRef)> setNameFn) {
1349 setNameFn(getResult(), "extracted");
1350}
1351
1352LogicalResult ExtractOp::verify() {
1353 // Verify the # indices match if we have a ranked type.
1354 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1355 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1356 return emitOpError("incorrect number of indices for extract_element");
1357 return success();
1358}
1359
1360/// If we have an ExtractOp consuming an InsertOp with the same
1361/// indices, we can return the InsertOp's scalar directly.
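/// For example (illustrative IR), extracting %t[%i, %j] right after
/// %t = tensor.insert %v into %dest[%i, %j] folds to %v.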
1362// TODO: This only checks the immediate producer; extend to go up the
1363// insert/extract chain if the slices are disjoint.
1364static Value foldExtractAfterInsert(ExtractOp extractOp) {
1365 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1366
1367 auto isSame = [](Value a, Value b) {
1368 return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1369 };
1370 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1371 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1372 return insertOp.getScalar();
1373
1374 return {};
1375}
1376
1377OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1378 if (Attribute tensor = adaptor.getTensor()) {
1379 // If this is a splat elements attribute, simply return the value.
1380 // All of the elements of a splat attribute are the same.
1381 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1382 return splatTensor.getSplatValue<Attribute>();
1383
1384 // If this is a dense resource elements attribute, return.
1385 if (isa<DenseResourceElementsAttr>(tensor))
1386 return {};
1387 }
1388
1389 // Collect the constant indices into the tensor.
1390 SmallVector<uint64_t, 8> indices;
1391 for (Attribute indice : adaptor.getIndices()) {
1392 if (!indice || !llvm::isa<IntegerAttr>(indice))
1393 return {};
1394 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1395 }
1396
1397 // Fold extract(from_elements(...)).
1398 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1399 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1400 auto rank = tensorType.getRank();
1401 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1402 "rank mismatch");
1403 int flatIndex = 0;
1404 int stride = 1;
1405 for (int i = rank - 1; i >= 0; --i) {
1406 flatIndex += indices[i] * stride;
1407 stride *= tensorType.getDimSize(i);
1408 }
1409 // Prevent out of bounds accesses. This can happen in invalid code that
1410 // will never execute.
1411 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1412 flatIndex < 0)
1413 return {};
1414 return fromElementsOp.getElements()[flatIndex];
1415 }
1416
1417 // If this is an elements attribute, query the value at the given indices.
1418 if (Attribute tensor = adaptor.getTensor()) {
1419 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1420 if (elementsAttr && elementsAttr.isValidIndex(indices))
1421 return elementsAttr.getValues<Attribute>()[indices];
1422 }
1423
1424 if (Value result = foldExtractAfterInsert(*this))
1425 return result;
1426
1427 return {};
1428}
1429
1430void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1431 MLIRContext *context) {
1432 results.add<ExtractFromTensorCast>(context);
1433}
1434
1435void mlir::tensor::populateFoldCollapseExtractPatterns(
1436 RewritePatternSet &patterns) {
1437 patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1438}
1439
1440//===----------------------------------------------------------------------===//
1441// FromElementsOp
1442//===----------------------------------------------------------------------===//
1443
1444void FromElementsOp::getAsmResultNames(
1445 function_ref<void(Value, StringRef)> setNameFn) {
1446 setNameFn(getResult(), "from_elements");
1447}
1448
1449void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1450 ValueRange elements) {
1451 assert(!elements.empty() && "expected at least one element");
1452 Type resultType = RankedTensorType::get(
1453 {static_cast<int64_t>(elements.size())}, elements.front().getType());
1454 build(builder, result, resultType, elements);
1455}
1456
1457OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1458 if (!llvm::is_contained(adaptor.getElements(), nullptr))
1459 return DenseElementsAttr::get(getType(), adaptor.getElements());
1460 return {};
1461}
1462
1463namespace {
1464
1465// Pushes the index_casts that occur before extractions to after the extract.
1466// This minimizes type conversion in some cases and enables the extract
1467// canonicalizer. This changes:
1468//
1469// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1470// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1471//
1472// to the following:
1473//
1474// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1475// %cast = arith.index_cast %extract : i32 to index
1478//
1479// Consider expanding this to a template and handle all tensor cast
1480// operations.
1481struct ExtractElementFromIndexCast
1482 : public OpRewritePattern<tensor::ExtractOp> {
1483 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1484
1485 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1486 PatternRewriter &rewriter) const final {
1487 Location loc = extract.getLoc();
1488 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1489 if (!indexCast)
1490 return failure();
1491
1492 Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1493
1494 auto newExtract = tensor::ExtractOp::create(
1495 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1496
1497 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1498 newExtract);
1499
1500 return success();
1501 }
1502};
1503
1504} // namespace
1505
1506void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1507 MLIRContext *context) {
1508 results.add<ExtractElementFromIndexCast>(context);
1509}
1510
1511//===----------------------------------------------------------------------===//
1512// GatherOp
1513//===----------------------------------------------------------------------===//
1514
1515void GatherOp::getAsmResultNames(
1516 function_ref<void(Value, StringRef)> setNameFn) {
1517 setNameFn(getResult(), "gather");
1518}
1519
1520/// Return the inferred result type for a gatherOp where:
1521/// - sourceType is the type of the source tensor gathered from
1522/// - indicesType is the type of the indices used to gather
1523/// - gatherDims are the dims along which the gather occurs.
1524/// Return a full-rank or rank-reduced variant of the type depending on
1525/// the value of rankReduced.
1526///
1527/// The leading dimensions of the index tensor give the result tensor its
1528/// leading dimensions.
1529/// The trailing dimensions of the result tensor are obtained from the source
1530/// tensor by setting the dimensions specified in gather_dims to `1` (if
1531/// rankReduced is false), or skipping them (otherwise).
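/// For example (illustrative shapes): with source tensor<4x5x6xf32>, indices
/// tensor<1x3x2xindex>, and gather_dims [0, 2], the inferred type is
/// tensor<1x3x1x5x1xf32>, or tensor<1x3x5xf32> when rankReduced is true.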
1532RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1533 RankedTensorType indicesType,
1534 ArrayRef<int64_t> gatherDims,
1535 bool rankReduced) {
1536 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1537 resultShape.reserve(resultShape.size() + sourceType.getRank());
1538 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1539 if (llvm::binary_search(gatherDims, idx)) {
1540 if (!rankReduced)
1541 resultShape.push_back(1);
1542 continue;
1543 }
1544 resultShape.push_back(sourceType.getDimSize(idx));
1545 }
1546 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1547}
1548
1549static LogicalResult
1550verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1551 ArrayRef<int64_t> indices, int64_t rank,
1552 StringRef gatherOrScatter, StringRef sourceOrDest) {
1553 if (dims.empty())
1554 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1555
1556 int64_t numGatherDims = dims.size();
1557 if (numGatherDims > rank)
1558 return op->emitOpError(gatherOrScatter)
1559 << "_dims overflow " << sourceOrDest << " rank";
1560 if (indices.empty() || indices.back() != numGatherDims)
1561 return op->emitOpError(gatherOrScatter)
1562 << "_dims length must match the size of last dimension of indices";
1563 for (int64_t val : dims) {
1564 if (val < 0)
1565 return op->emitOpError(gatherOrScatter)
1566 << "_dims value must be non-negative";
1567 if (val >= rank)
1568 return op->emitOpError(gatherOrScatter)
1569 << "_dims value must be smaller than " << sourceOrDest << " rank";
1570 }
1571 for (int64_t i = 1; i < numGatherDims; ++i) {
1572 if (dims[i - 1] >= dims[i])
1573 return op->emitOpError(gatherOrScatter)
1574 << "_dims values must be strictly increasing";
1575 }
1576 return success();
1577}
1578
1579LogicalResult GatherOp::verify() {
1580 int64_t sourceRank = getSourceType().getRank();
1581 ArrayRef<int64_t> gatherDims = getGatherDims();
1582 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1583 getIndicesType().getShape(), sourceRank,
1584 "gather", "source")))
1585 return failure();
1586
1587 RankedTensorType expectedResultType = GatherOp::inferResultType(
1588 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1589 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1590 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1591 if (getResultType() != expectedResultType &&
1592 getResultType() != expectedRankReducedResultType) {
1593 return emitOpError("result type "
1594 "mismatch: "
1595 "expected ")
1596 << expectedResultType << " or its rank-reduced variant "
1597 << expectedRankReducedResultType << " (got: " << getResultType()
1598 << ")";
1599 }
1600
1601 return success();
1602}
1603
1604OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1605 if (OpFoldResult reshapedSource = reshapeConstantSource(
1606 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1607 getResult().getType()))
1608 return reshapedSource;
1609 return {};
1610}
1611
1612//===----------------------------------------------------------------------===//
1613// InsertOp
1614//===----------------------------------------------------------------------===//
1615
1616void InsertOp::getAsmResultNames(
1617 function_ref<void(Value, StringRef)> setNameFn) {
1618 setNameFn(getResult(), "inserted");
1619}
1620
1621LogicalResult InsertOp::verify() {
1622 // Verify the # indices match if we have a ranked type.
1623 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1624 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1625 return emitOpError("incorrect number of indices");
1626 return success();
1627}
1628
1629OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1630 Attribute scalar = adaptor.getScalar();
1631 Attribute dest = adaptor.getDest();
1632 if (scalar && dest)
1633 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1634 if (scalar == splatDest.getSplatValue<Attribute>())
1635 return dest;
1636 return {};
1637}
1638
1639//===----------------------------------------------------------------------===//
1640// GenerateOp
1641//===----------------------------------------------------------------------===//
1642
1643void GenerateOp::getAsmResultNames(
1644 function_ref<void(Value, StringRef)> setNameFn) {
1645 setNameFn(getResult(), "generated");
1646}
1647
1648LogicalResult GenerateOp::reifyResultShapes(
1649 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1650 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1651 int idx = 0;
1652 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1653 if (getType().isDynamicDim(dim)) {
1654 reifiedReturnShapes[0][dim] = getOperand(idx++);
1655 } else {
1656 reifiedReturnShapes[0][dim] =
1657 builder.getIndexAttr(getType().getDimSize(dim));
1658 }
1659 }
1660 return success();
1661}
1662
1663LogicalResult GenerateOp::verify() {
1664 // Ensure that the tensor type has as many dynamic dimensions as are
1665 // specified by the operands.
1666 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1667 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
1668 getOperands())))
1669 return failure();
1670 return success();
1671}
1672
1673LogicalResult GenerateOp::verifyRegions() {
1674 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1675 // Ensure that region arguments span the index space.
1676 if (!llvm::all_of(getBody().getArgumentTypes(),
1677 [](Type ty) { return ty.isIndex(); }))
1678 return emitError("all body arguments must be index");
1679 if (getBody().getNumArguments() != resultTy.getRank())
1680 return emitError("must have one body argument per input dimension");
1681
1682 // Ensure that the region yields an element of the right type.
1683 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1684
1685 if (yieldOp.getValue().getType() != resultTy.getElementType())
1686 return emitOpError(
1687 "body must be terminated with a `yield` operation of the tensor "
1688 "element type");
1689
1690 return success();
1691}
1692
1693void GenerateOp::build(
1694 OpBuilder &b, OperationState &result, Type resultTy,
1695 ValueRange dynamicExtents,
1696 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1697 build(b, result, resultTy, dynamicExtents);
1698
1699 // Build and populate body.
1700 OpBuilder::InsertionGuard guard(b);
1701 Region *bodyRegion = result.regions.front().get();
1702 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1703 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1704 SmallVector<Location, 2> argumentLocs(rank, result.location);
1705 Block *bodyBlock =
1706 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1707 bodyBuilder(b, result.location, bodyBlock->getArguments());
1708}
1709
1710namespace {
1711
1712/// Canonicalizes tensor.generate operations with constant extent operands
1713/// into the equivalent operation with those extents folded into the result
1714/// type. We also insert a type cast to make sure that the resulting IR is
1715/// still well-typed.
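///
/// For example (an illustrative sketch of the rewrite):
///
/// ```mlir
/// %c5 = arith.constant 5 : index
/// %0 = tensor.generate %c5 {
/// ^bb0(%i : index):
///   tensor.yield %i : index
/// } : tensor<?xindex>
/// ```
///
/// becomes:
///
/// ```mlir
/// %new = tensor.generate {
/// ^bb0(%i : index):
///   tensor.yield %i : index
/// } : tensor<5xindex>
/// %0 = tensor.cast %new : tensor<5xindex> to tensor<?xindex>
/// ```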
1716struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1717 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1718
1719 LogicalResult matchAndRewrite(GenerateOp generateOp,
1720 PatternRewriter &rewriter) const final {
1721 SmallVector<Value> foldedDynamicSizes;
1722 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1723 generateOp.getType(), generateOp.getDynamicExtents(),
1724 foldedDynamicSizes);
1725
1726 // Stop here if no dynamic size was promoted to static.
1727 if (foldedTensorType == generateOp.getType())
1728 return failure();
1729
1730 auto loc = generateOp.getLoc();
1731 auto newOp =
1732 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1733 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1734 newOp.getBody().begin());
1735 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1736 generateOp.getType(), newOp);
1737 return success();
1738 }
1739};
1740
1741/// Canonicalizes the pattern of the form
1742///
1743/// %tensor = tensor.generate %x {
1744/// ^bb0(%arg0: index):
1745/// <computation>
1746/// yield %1 : index
1747/// } : tensor<?xindex>
1748/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1749///
1750/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1751/// tensor.generate operation has no side-effects.
1752struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1753 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1754
1755 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1756 PatternRewriter &rewriter) const final {
1757 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1758 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1759 return failure();
1760
1761 IRMapping mapping;
1762 Block *body = &tensorFromElements.getBody().front();
1763 mapping.map(body->getArguments(), extract.getIndices());
1764 for (auto &op : body->without_terminator())
1765 rewriter.clone(op, mapping);
1766
1767 auto yield = cast<YieldOp>(body->getTerminator());
1768
1769 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1770 return success();
1771 }
1772};
1773
1774} // namespace
1775
1776void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1777 MLIRContext *context) {
1778 // TODO: Move extract pattern to tensor::ExtractOp.
1779 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1780}
1781
1782//===----------------------------------------------------------------------===//
1783// RankOp
1784//===----------------------------------------------------------------------===//
1785
1786void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1787 setNameFn(getResult(), "rank");
1788}
1789
1790OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1791 // Constant fold rank when the rank of the operand is known.
1792 auto type = getOperand().getType();
1793 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1794 if (shapedType && shapedType.hasRank())
1795 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1796 return IntegerAttr();
1797}
1798
1799//===----------------------------------------------------------------------===//
1800// ReshapeOp
1801//===----------------------------------------------------------------------===//
1802
1803void ReshapeOp::getAsmResultNames(
1804 function_ref<void(Value, StringRef)> setNameFn) {
1805 setNameFn(getResult(), "reshape");
1806}
1807
1808static int64_t getNumElements(ShapedType type) {
1809 int64_t numElements = 1;
1810 for (auto dim : type.getShape())
1811 numElements *= dim;
1812 return numElements;
1813}
1814
1815LogicalResult ReshapeOp::verify() {
1816 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1817 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1818
1819 if (operandType.getElementType() != resultType.getElementType())
1820 return emitOpError("element types of source and destination tensor "
1821 "types should be the same");
1822
1823 int64_t shapeSize =
1824 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1825 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1826 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1827
1828 if (resultRankedType) {
1829 if (operandRankedType && resultRankedType.hasStaticShape() &&
1830 operandRankedType.hasStaticShape()) {
1831 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1832 return emitOpError("source and destination tensor should have the "
1833 "same number of elements");
1834 }
1835 if (ShapedType::isDynamic(shapeSize))
1836 return emitOpError("cannot use shape operand with dynamic length to "
1837 "reshape to statically-ranked tensor type");
1838 if (shapeSize != resultRankedType.getRank())
1839 return emitOpError(
1840 "length of shape operand differs from the result's tensor rank");
1841 }
1842 return success();
1843}
1844
1845OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1846 if (OpFoldResult reshapedSource = reshapeConstantSource(
1847 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1848 getResult().getType()))
1849 return reshapedSource;
1850
1851 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1852 // producer's input instead as the original tensor to reshape. This could
1853 // render such producer dead code.
1854 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1855 getSourceMutable().assign(reshapeOpProducer.getSource());
1856 return getResult();
1857 }
1858
1859 auto source = getSource();
1860 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1861 auto resultTy = dyn_cast<RankedTensorType>(getType());
1862 if (!sourceTy || !resultTy || sourceTy != resultTy)
1863 return {};
1864
1865 // If the source and result are both 0D or 1D tensors and have the same type,
1866 // the reshape has no effect, even if the tensor is dynamically shaped.
1867 if (sourceTy.getRank() <= 1)
1868 return source;
1869
1870 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1871 auto elements = fromElements.getElements();
1872 bool dynamicNoop =
1873 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1874 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1875 auto element = elements[id];
1876
1877 if (auto cst = getConstantIntValue(element)) {
1878 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1879 continue;
1880 }
1881
1882 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1883 dynamicNoop &= dimOp.getSource() == source;
1884
1885 auto cst = getConstantIntValue(dimOp.getIndex());
1886 dynamicNoop &=
1887 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1888 continue;
1889 }
1890
1891 dynamicNoop = false;
1892 break;
1893 }
1894
1895 if (dynamicNoop)
1896 return source;
1897 }
1898
1899 return {};
1900}
1901
1902//===----------------------------------------------------------------------===//
1903// Reassociative reshape ops
1904//===----------------------------------------------------------------------===//
1905
1906void CollapseShapeOp::getAsmResultNames(
1907 function_ref<void(Value, StringRef)> setNameFn) {
1908 setNameFn(getResult(), "collapsed");
1909}
1910
1911void ExpandShapeOp::getAsmResultNames(
1912 function_ref<void(Value, StringRef)> setNameFn) {
1913 setNameFn(getResult(), "expanded");
1914}
1915
1916int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1917 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1918 "invalid resultDim");
1919 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1920 if (llvm::is_contained(it.value(), resultDim))
1921 return it.index();
1922 llvm_unreachable("could not find reassociation group");
1923}
1924
1925FailureOr<SmallVector<OpFoldResult>>
1926ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1927 RankedTensorType expandedType,
1928 ArrayRef<ReassociationIndices> reassociation,
1929 ArrayRef<OpFoldResult> inputShape) {
1930 std::optional<SmallVector<OpFoldResult>> outputShape =
1931 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1932 inputShape);
1933 if (!outputShape)
1934 return failure();
1935 return *outputShape;
1936}
1937
1938SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1939 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1940}
1941
1942void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1943 Type resultType, Value src,
1944 ArrayRef<ReassociationIndices> reassociation,
1945 ArrayRef<OpFoldResult> outputShape) {
1946 auto [staticOutputShape, dynamicOutputShape] =
1947 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1948 build(builder, result, cast<RankedTensorType>(resultType), src,
1949 getReassociationIndicesAttribute(builder, reassociation),
1950 dynamicOutputShape, staticOutputShape);
1951}
1952
1953void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1954 Type resultType, Value src,
1955 ArrayRef<ReassociationIndices> reassociation) {
1956 SmallVector<OpFoldResult> inputShape =
1957 getMixedSizes(builder, result.location, src);
1958 auto tensorResultTy = cast<RankedTensorType>(resultType);
1959 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1960 builder, result.location, tensorResultTy, reassociation, inputShape);
1961 SmallVector<OpFoldResult> outputShapeOrEmpty;
1962 if (succeeded(outputShape)) {
1963 outputShapeOrEmpty = *outputShape;
1964 }
1965 build(builder, result, tensorResultTy, src, reassociation,
1966 outputShapeOrEmpty);
1967}
1968
1969SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1970 return getSymbolLessAffineMaps(getReassociationExprs());
1971}
1972SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1973 return convertReassociationIndicesToExprs(getContext(),
1974 getReassociationIndices());
1975}
1976
1977SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1978 return getSymbolLessAffineMaps(getReassociationExprs());
1979}
1980SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1981 return convertReassociationIndicesToExprs(getContext(),
1982 getReassociationIndices());
1983}
1984
1985RankedTensorType CollapseShapeOp::inferCollapsedType(
1986 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1987 return inferCollapsedType(
1988 type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1989 type.getContext(), reassociation)));
1990}
1991
1992/// Compute the RankedTensorType obtained by applying `reassociation` to
1993/// `type`.
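///
/// For example, collapsing tensor<2x3x?x5xf32> with reassociation
/// [[0, 1], [2, 3]] yields tensor<6x?xf32>: the sizes within each group are
/// multiplied, and any group containing a dynamic size collapses to a
/// dynamic size.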
1994RankedTensorType
1995CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1996 ArrayRef<AffineMap> reassociation) {
1997 auto shape = type.getShape();
1998 SmallVector<int64_t, 4> newShape;
1999 newShape.reserve(reassociation.size());
2000
2001 // Use the fact that reassociation is valid to simplify the logic: only use
2002 // each map's rank.
2003 assert(isReassociationValid(reassociation) && "invalid reassociation");
2004 unsigned currentDim = 0;
2005 for (AffineMap m : reassociation) {
2006 unsigned dim = m.getNumResults();
2007 auto band = shape.slice(currentDim, dim);
2008 int64_t size = 1;
2009 if (llvm::is_contained(band, ShapedType::kDynamic))
2010 size = ShapedType::kDynamic;
2011 else
2012 for (unsigned d = 0; d < dim; ++d)
2013 size *= shape[currentDim + d];
2014 newShape.push_back(size);
2015 currentDim += dim;
2016 }
2017
2018 return RankedTensorType::get(newShape, type.getElementType());
2019}
2020
2021void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2022 ArrayRef<ReassociationIndices> reassociation,
2023 ArrayRef<NamedAttribute> attrs) {
2024 auto srcType = llvm::cast<RankedTensorType>(src.getType());
2025 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2026 auto resultType =
2027 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2028 srcType.getEncoding());
2029 result.addAttribute(getReassociationAttrStrName(),
2030 getReassociationIndicesAttribute(b, reassociation));
2031 build(b, result, resultType, src, attrs);
2032}
2033
2034template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2035 TensorReshapeOp, ExpandShapeOp>::value>
2036static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2037 RankedTensorType expandedType,
2038 RankedTensorType collapsedType) {
2039 if (failed(
2040 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2041 return failure();
2042
2043 auto maps = op.getReassociationMaps();
2044 RankedTensorType expectedType =
2045 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2046 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2047 return op.emitOpError("expected collapsed type to be ")
2048 << expectedType << ", but got " << collapsedType;
2049 return success();
2050}
2051
2052LogicalResult ExpandShapeOp::verify() {
2053 auto srcType = getSrcType();
2054 auto resultType = getResultType();
2055
2056 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2057 return emitOpError("expected number of static shape dims to be equal to "
2058 "the output rank (")
2059 << resultType.getRank() << ") but found "
2060 << getStaticOutputShape().size() << " inputs instead";
2061
2062 if ((int64_t)getOutputShape().size() !=
2063 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2064 return emitOpError("mismatch in dynamic dims in output_shape and "
2065 "static_output_shape: static_output_shape has ")
2066 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2067 << " dynamic dims while output_shape has " << getOutputShape().size()
2068 << " values";
2069
2070 return verifyTensorReshapeOp(*this, resultType, srcType);
2071}
2072
2073LogicalResult CollapseShapeOp::verify() {
2074 CollapseShapeOp op = *this;
2075 if (llvm::any_of(op.getReassociationIndices(),
2076 [](ReassociationIndices group) { return group.empty(); })) {
2077 return op.emitOpError("reassociation indices must not be empty");
2078 }
2079 return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2080}
2081
2082namespace {
2083/// Reshape of a splat constant can be replaced with a constant of the result
2084/// type.
2085template <typename TensorReshapeOp>
2086struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2087 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2088 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2089 PatternRewriter &rewriter) const override {
2090 DenseElementsAttr attr;
2091 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2092 return failure();
2093 if (!attr || !attr.isSplat())
2094 return failure();
2095 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2096 reshapeOp.getResultType(), attr.getRawData());
2097 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2098 return success();
2099 }
2100};
2101
2102// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2103template <typename TensorReshapeOp>
2104class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2105public:
2106 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2107
2108 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2109 PatternRewriter &rewriter) const override {
2110 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2111 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2112 return failure();
2113
2114 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2115 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2116 return success();
2117 }
2118};
2119
2120/// Reshape of a FromElements can be replaced with a FromElements of the
2121/// result type
2122template <typename TensorReshapeOp>
2123struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2124 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2125 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2126 PatternRewriter &rewriter) const override {
2127 auto fromElements =
2128 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2129 if (!fromElements)
2130 return failure();
2131
2132 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2133
2134 if (!shapedTy.hasStaticShape())
2135 return failure();
2136
2137 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2138 fromElements.getElements());
2139 return success();
2140 }
2141};
2142
2143// Fold CastOp into CollapseShapeOp when adding static information.
2144struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2145 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2146
2147 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2148 PatternRewriter &rewriter) const override {
2149 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2150 if (!tensor::canFoldIntoConsumerOp(castOp))
2151 return failure();
2152
2153 RankedTensorType srcType =
2154 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2155 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2156 srcType, collapseShapeOp.getReassociationMaps());
2157
2158 if (newResultType == collapseShapeOp.getResultType()) {
2159 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2160 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2161 });
2162 } else {
2163 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2164 newResultType, castOp.getSource(),
2165 collapseShapeOp.getReassociation());
2166 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2167 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2168 }
2169 return success();
2170 }
2171};
2172
2173/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2174/// matching constant output_shape operands of the expand. This makes the
2175/// `tensor.expand_shape` more static and creates a consumer cast that can be
2176/// propagated further.
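///
/// For example (an illustrative sketch, assuming %c2 is a constant 2):
///
/// ```mlir
/// %cast = tensor.cast %src : tensor<6xf32> to tensor<?xf32>
/// %e = tensor.expand_shape %cast [[0, 1]] output_shape [%c2, 3]
///        : tensor<?xf32> into tensor<?x3xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
/// %cast2 = tensor.cast %cast : tensor<?xf32> to tensor<6xf32>
/// %e2 = tensor.expand_shape %cast2 [[0, 1]] output_shape [2, 3]
///         : tensor<6xf32> into tensor<2x3xf32>
/// %r = tensor.cast %e2 : tensor<2x3xf32> to tensor<?x3xf32>
/// ```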
2177struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2178 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2179
2180 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2181 PatternRewriter &rewriter) const override {
2182 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2183 if (!canFoldIntoConsumerOp(castOp))
2184 return failure();
2185
2186 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2187 SmallVector<ReassociationIndices, 4> reassoc =
2188 expandOp.getReassociationIndices();
2189
2190 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2191 SmallVector<Value> dynamicOutputShape;
2192 auto outputIt = expandOp.getOutputShape().begin();
2193
2194 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2195 for (uint64_t outDim : innerReassoc) {
2196 if (ShapedType::isStatic(newOutputShape[outDim]))
2197 continue;
2198
2199 // If the cast's src type is dynamic, don't infer any of the
2200 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2201 // least one of the expanded dimensions to be dynamic if the input is
2202 // dynamic.
2203 Value val = *outputIt;
2204 ++outputIt;
2205 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2206 dynamicOutputShape.push_back(val);
2207 continue;
2208 }
2209
2210 APInt cst;
2211 if (matchPattern(val, m_ConstantInt(&cst))) {
2212 newOutputShape[outDim] = cst.getSExtValue();
2213 } else {
2214 dynamicOutputShape.push_back(val);
2215 }
2216 }
2217 }
2218
2219 // Couldn't match any values, nothing to change
2220 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2221 return failure();
2222
2223 // Calculate the input shape from the output
2224 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2225 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2226 for (auto outDim : reassoc[inDim]) {
2227 auto ofr = newOutputShape[outDim];
2228 if (ShapedType::isDynamic(ofr)) {
2229 newInputShape[inDim] = ShapedType::kDynamic;
2230 break;
2231 }
2232 newInputShape[inDim] *= ofr;
2233 }
2234 }
2235
2236 SmallVector<OpFoldResult> outputOfr =
2237 getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2238 auto inputType = RankedTensorType::get(
2239 newInputShape, expandOp.getSrcType().getElementType());
2240 auto outputType = RankedTensorType::get(
2241 newOutputShape, expandOp.getSrcType().getElementType());
2242 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2243 expandOp.getSrc());
2244 auto newExpand = ExpandShapeOp::create(
2245 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2246 expandOp.getReassociationIndices(), outputOfr);
2247 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2248 newExpand.getResult());
2249 return success();
2250 }
2251};
2252} // namespace
2253
2254void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2255 MLIRContext *context) {
2256 results.add<
2257 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2258 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2259 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2260 FoldReshapeWithSplat<ExpandShapeOp>,
2261 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2262}
2263
2264void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2265 MLIRContext *context) {
2266 results.add<
2267 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2268 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2269 tensor::DimOp, RankedTensorType>,
2270 FoldReshapeWithConstant<CollapseShapeOp>,
2271 FoldReshapeWithSplat<CollapseShapeOp>,
2272 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2273 context);
2274}
2275
2276OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2277 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2278 adaptor.getOperands());
2279}
2280
2281OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2282 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2283 adaptor.getOperands());
2284}
2285
2286//===----------------------------------------------------------------------===//
2287// ExtractSliceOp
2288//===----------------------------------------------------------------------===//
2289
2290void ExtractSliceOp::getAsmResultNames(
2291 function_ref<void(Value, StringRef)> setNameFn) {
2292 setNameFn(getResult(), "extracted_slice");
2293}
2294
2295/// An extract_slice result type can be inferred, when it is not
2296/// rank-reduced, from the source type and the static representation of
2297/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2298RankedTensorType
2299ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2300 ArrayRef<int64_t> staticSizes) {
2301 // An extract_slice op may specify only a leading subset of offset/sizes/
2302 // strides, in which case we complete with offset=0, sizes from the source
2303 // tensor type, and strides=1.
2304 assert(static_cast<int64_t>(staticSizes.size()) ==
2305 sourceTensorType.getRank() &&
2306 "unexpected staticSizes not equal to rank of source");
2307 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2308 sourceTensorType.getEncoding());
2309}
2310
2311// TODO: This uses neither offsets nor strides!
2312RankedTensorType
2313ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2314 ArrayRef<OpFoldResult> sizes) {
2315 SmallVector<int64_t> staticSizes;
2316 std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2317
2318 assert(static_cast<int64_t>(staticSizes.size()) ==
2319 sourceTensorType.getRank() &&
2320 "unexpected staticSizes not equal to rank of source");
2321 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2322 sourceTensorType.getEncoding());
2323}
2324
2325/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2326/// number of sizes), drop as many size-1 dimensions as needed to produce an inferred
2327/// type with the desired rank.
2328///
2329/// Note that there may be multiple ways to compute this rank-reduced type:
2330/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2331///
2332/// To disambiguate, this function always drops the leading size-1 dimensions first.
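/// For example, rank-reducing sizes 1x6x1 to a desired rank of 2 drops the
/// leading unit dimension and yields a 6x1 result type.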
2333RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2334 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2335 ArrayRef<int64_t> sizes) {
2336 // Type inferred in the absence of rank-reducing behavior.
2337 auto inferredType = llvm::cast<RankedTensorType>(
2338 inferResultType(sourceRankedTensorType, sizes));
2339 int rankDiff = inferredType.getRank() - desiredResultRank;
2340 if (rankDiff > 0) {
2341 auto shape = inferredType.getShape();
2342 llvm::SmallBitVector dimsToProject =
2343 getPositionsOfShapeOne(rankDiff, shape);
2344 SmallVector<int64_t> projectedShape;
2345 // Best effort rank-reducing: drop 1s in order.
2346 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2347 if (!dimsToProject.test(pos))
2348 projectedShape.push_back(shape[pos]);
2349 inferredType =
2350 RankedTensorType::get(projectedShape, inferredType.getElementType());
2351 }
2352 return inferredType;
2353}
2354
2355RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2356 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2357 ArrayRef<OpFoldResult> sizes) {
2358 SmallVector<int64_t> staticSizes;
2359 SmallVector<Value> dynamicSizes;
2360 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2361 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2362 desiredResultRank, sourceRankedTensorType, staticSizes);
2363}
2364
2365/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2366/// result type. If the type passed is nullptr, it is inferred.
2367void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2368 RankedTensorType resultType, Value source,
2369 ArrayRef<OpFoldResult> offsets,
2370 ArrayRef<OpFoldResult> sizes,
2371 ArrayRef<OpFoldResult> strides,
2372 ArrayRef<NamedAttribute> attrs) {
2373 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2374 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2375 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2376 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2377 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2378 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2379 // Structuring implementation this way avoids duplication between builders.
2380 if (!resultType) {
2381 resultType = llvm::cast<RankedTensorType>(
2382 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2383 }
2384 result.addAttributes(attrs);
2385 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2386 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2387 b.getDenseI64ArrayAttr(staticSizes),
2388 b.getDenseI64ArrayAttr(staticStrides));
2389}
2390
2391/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2392/// result type.
2393void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2394 ArrayRef<OpFoldResult> offsets,
2395 ArrayRef<OpFoldResult> sizes,
2396 ArrayRef<OpFoldResult> strides,
2397 ArrayRef<NamedAttribute> attrs) {
2398 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2399}
2400
2401/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2402/// a Range vector.
2403void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2404 ArrayRef<Range> ranges,
2405 ArrayRef<NamedAttribute> attrs) {
2406 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2407 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2408}
2409
2410/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2411/// the type passed is nullptr, it is inferred.
2412void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2413 RankedTensorType resultType, Value source,
2414 ValueRange offsets, ValueRange sizes,
2415 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2416 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2417 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2418 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2419 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2420 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2421 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2422 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2423}
2424
2425/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2426void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2427 ValueRange offsets, ValueRange sizes,
2428 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2429 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2430}
2431
2432static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2433 Operation *op,
2434 RankedTensorType expectedType) {
2435 switch (result) {
2436 case SliceVerificationResult::Success:
2437 return success();
2438 case SliceVerificationResult::RankTooLarge:
2439 return op->emitError("expected rank to be smaller or equal to ")
2440 << "the other rank. ";
2441 case SliceVerificationResult::SizeMismatch:
2442 return op->emitError("expected type to be ")
2443 << expectedType << " or a rank-reduced version. (size mismatch) ";
2444 case SliceVerificationResult::ElemTypeMismatch:
2445 return op->emitError("expected element type to be ")
2446 << expectedType.getElementType();
2447 default:
2448 llvm_unreachable("unexpected extract_slice op verification result");
2449 }
2450}
2451
2452/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred
2453/// result type, offsets set to 0 and strides set to 1.
2454void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2455 RankedTensorType resultType, Value source,
2456 ArrayRef<OpFoldResult> sizes,
2457 ArrayRef<NamedAttribute> attrs) {
2458 Attribute zeroIdxAttr = b.getIndexAttr(0);
2459 Attribute oneIdxAttr = b.getIndexAttr(1);
2460 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2461 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2462 build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
2463}
2464
2465/// Verifier for ExtractSliceOp.
2466LogicalResult ExtractSliceOp::verify() {
2467 RankedTensorType sourceType = getSourceType();
2468
2469 // Verify result type against inferred type.
2470 RankedTensorType expectedType =
2471 ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
2472 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2473 if (result != SliceVerificationResult::Success)
2474 return produceSliceErrorMsg(result, *this, expectedType);
2475
2476 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2477 // to the source tensor.
2478 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2479 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2480 getStaticStrides(), /*generateErrorMessage=*/true);
2481 if (!boundsResult.isValid)
2482 return getOperation()->emitError(boundsResult.errorMessage);
2483
2484 return success();
2485}
2486
2487llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2488 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2489}
2490
2491FailureOr<Value>
2492ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2493 ArrayRef<int64_t> desiredShape) {
2494 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2495 assert(sourceTensorType && "not a ranked tensor type");
2496 auto sourceShape = sourceTensorType.getShape();
2497 if (sourceShape.equals(desiredShape))
2498 return value;
2499 auto maybeRankReductionMask =
2500 mlir::computeRankReductionMask(sourceShape, desiredShape);
2501 if (!maybeRankReductionMask)
2502 return failure();
2503 return createCanonicalRankReducingExtractSliceOp(
2504 b, loc, value,
2505 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2506}
2507
2508LogicalResult ExtractSliceOp::reifyResultShapes(
2509 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2510 reifiedReturnShapes.resize(1);
2511 reifiedReturnShapes[0].reserve(getType().getRank());
2512 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2513 llvm::SmallBitVector droppedDims = getDroppedDims();
2514 for (const auto &size : enumerate(mixedSizes)) {
2515 if (droppedDims.test(size.index()))
2516 continue;
2517 reifiedReturnShapes[0].push_back(size.value());
2518 }
2519 return success();
2520}
2521
2522namespace {
2523/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2524/// This essentially pushes the tensor.cast past its consuming extract_slice when
2525/// `canFoldIntoConsumerOp` is true.
2526///
2527/// Example:
2528/// ```
2529/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2530/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2531/// tensor<3x4xf32>
2532/// ```
2533/// is rewritten into:
2534/// ```
2535/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2536/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2537/// ```
2538class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2539public:
2540 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2541
2542 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2543 PatternRewriter &rewriter) const override {
2544 // Any constant operand, just return to let the constant folder kick in.
2545 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2546 return matchPattern(operand, matchConstantIndex());
2547 }))
2548 return failure();
2549
2550 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2551 if (!castOp)
2552 return failure();
2553
2554 if (!canFoldIntoConsumerOp(castOp))
2555 return failure();
2556
2557 // Pattern does not apply if the produced op would not verify.
2558 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2559 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2560 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2561 sliceOp.getStaticStrides());
2562 if (!sliceResult.isValid)
2563 return failure();
2564
2565 // Create folded extract.
2566 Location loc = sliceOp.getLoc();
2567 Value newResult = ExtractSliceOp::create(
2568 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2569 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2570 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2571 sliceOp.getStaticStrides());
2572 rewriter.replaceOp(sliceOp, newResult);
2573 return success();
2574 }
2575};
2576
2577/// Slice elements from `values` into `outValues`. For each dimension, `counts`
2578/// holds the number of elements stepped over in `values` per unit index step.
2579/// The output values can be used to construct a DenseElementsAttr.
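/// For example, for a source of shape 2x3x4, `counts` is [12, 4, 1]:
/// advancing by one along dimension d skips counts[d] elements in the
/// flattened buffer.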
2580template <typename IterTy, typename ElemTy>
2581static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2582 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2583 ArrayRef<int64_t> strides,
2584 llvm::SmallVectorImpl<ElemTy> *outValues) {
2585 assert(offsets.size() == sizes.size());
2586 assert(offsets.size() == strides.size());
2587 if (offsets.empty())
2588 return;
2589
2590 int64_t offset = offsets.front();
2591 int64_t size = sizes.front();
2592 int64_t stride = strides.front();
2593 if (offsets.size() == 1) {
2594 for (int64_t i = 0; i < size; ++i, offset += stride)
2595 outValues->push_back(*(values + offset));
2596
2597 return;
2598 }
2599
2600 for (int64_t i = 0; i < size; ++i, offset += stride) {
2601 auto begin = values + offset * counts.front();
2602 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2603 offsets.drop_front(), sizes.drop_front(),
2604 strides.drop_front(), outValues);
2605 }
2606}
2607
2608/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2609/// folded operation might introduce more constant data; users can control
2610/// whether the fold applies via the control function.
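///
/// For example (illustrative, and subject to the control function):
///
/// ```mlir
/// %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
/// %s = tensor.extract_slice %cst[0, 1] [2, 1] [1, 1]
///        : tensor<2x2xi32> to tensor<2x1xi32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %s = arith.constant dense<[[1], [3]]> : tensor<2x1xi32>
/// ```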
2611class ConstantOpExtractSliceFolder final
2612 : public OpRewritePattern<ExtractSliceOp> {
2613public:
2614 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2615
2616 ConstantOpExtractSliceFolder(MLIRContext *context,
2617 ControlConstantExtractSliceFusionFn controlFn)
2618 : OpRewritePattern<ExtractSliceOp>(context),
2619 controlFn(std::move(controlFn)) {}
2620
2621 LogicalResult matchAndRewrite(ExtractSliceOp op,
2622 PatternRewriter &rewriter) const override {
2623 DenseElementsAttr attr;
2624 if (!matchPattern(op.getSource(), m_Constant(&attr)))
2625 return failure();
2626
2627 // A constant splat is handled by fold().
2628 if (attr.isSplat())
2629 return failure();
2630
2631 // Dynamic result shape is not supported.
2632 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2633 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2634 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2635 return failure();
2636
2637 // Customized control over the folding.
2638 if (!controlFn(op))
2639 return failure();
2640
2641 int64_t count = sourceType.getNumElements();
2642 if (count == 0)
2643 return failure();
2644
2645 // Check if there are any dynamic parts, which are not supported.
2646 auto offsets = op.getStaticOffsets();
2647 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2648 return failure();
2649 auto sizes = op.getStaticSizes();
2650 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2651 return failure();
2652 auto strides = op.getStaticStrides();
2653 if (llvm::is_contained(strides, ShapedType::kDynamic))
2654 return failure();
2655
2656 // Compute the stride for each dimension.
2657 SmallVector<int64_t> counts;
2658 ArrayRef<int64_t> shape = sourceType.getShape();
2659 counts.reserve(shape.size());
2660 for (int64_t v : shape) {
2661 count = count / v;
2662 counts.push_back(count);
2663 }
2664
2665 // New attribute constructed by the sliced values.
2666 DenseElementsAttr newAttr;
2667
2668 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2669 SmallVector<APInt> outValues;
2670 outValues.reserve(sourceType.getNumElements());
2671 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2672 elems.begin(), counts, offsets, sizes, strides, &outValues);
2673 newAttr = DenseElementsAttr::get(resultType, outValues);
2674 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2675 SmallVector<APFloat> outValues;
2676 outValues.reserve(sourceType.getNumElements());
2677 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2678 elems.begin(), counts, offsets, sizes, strides, &outValues);
2679 newAttr = DenseElementsAttr::get(resultType, outValues);
2680 }
2681
2682 if (newAttr) {
2683 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2684 return success();
2685 }
2686
2687 return failure();
2688 }
2689
2690private:
2691 /// This additionally controls whether the fold happens or not. Users can
2692 /// impose their heuristics in the function.
2693 ControlConstantExtractSliceFusionFn controlFn;
2694};
2695
2696} // namespace
2697
2698void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2699 RewritePatternSet &patterns,
2700 const ControlConstantExtractSliceFusionFn &controlFn) {
2701 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2702}
2703
2704/// Return the canonical type of the result of an extract_slice op.
2705struct SliceReturnTypeCanonicalizer {
2706 RankedTensorType operator()(ExtractSliceOp op,
2707 ArrayRef<OpFoldResult> mixedOffsets,
2708 ArrayRef<OpFoldResult> mixedSizes,
2709 ArrayRef<OpFoldResult> mixedStrides) {
2710 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2711 op.getType().getRank(), op.getSourceType(), mixedSizes);
2712 }
2713};
2714
2715/// A canonicalizer wrapper to replace ExtractSliceOps.
2716struct SliceCanonicalizer {
2717 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2718 ExtractSliceOp newOp) {
2719 Value replacement = newOp.getResult();
2720 if (replacement.getType() != op.getType())
2721 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2722 replacement);
2723 rewriter.replaceOp(op, replacement);
2724 }
2725};
2726
2727void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2728 MLIRContext *context) {
2729 results.add<
2730 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2731 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2732 ExtractSliceOpCastFolder>(context);
2733}
2734
2735/// Return success if `op` describes an identity slice of `shapedType`: all offsets are 0, the sizes match the leading dimensions of the shape, and all strides are 1.
2736static LogicalResult
2737foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2738 ShapedType shapedType) {
2739 OpBuilder b(op.getContext());
2740 for (OpFoldResult ofr : op.getMixedOffsets())
2741 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2742 return failure();
2743 // Rank-reducing noops only need to inspect the leading dimensions:
2744 // llvm::zip is appropriate.
2745 auto shape = shapedType.getShape();
2746 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2747 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2748 return failure();
2749 for (OpFoldResult ofr : op.getMixedStrides())
2750 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2751 return failure();
2752 return success();
2753}
2754
2755/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2756/// slice, we can return the InsertSliceOp's source directly.
2757// TODO: This only checks the immediate producer; extend to go up the
2758// insert/extract chain if the slices are disjoint.
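/// For example:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 1] [2, 3] [1, 1]
///        : tensor<2x3xf32> into tensor<4x5xf32>
/// %1 = tensor.extract_slice %0[0, 1] [2, 3] [1, 1]
///        : tensor<4x5xf32> to tensor<2x3xf32>
/// ```
///
/// Here %1 folds directly to %src.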
2759static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2760 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2761
2762 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2763 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2764 insertOp.isSameAs(extractOp, isSame))
2765 return insertOp.getSource();
2766
2767 return {};
2768}
2769
2770OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2771 if (OpFoldResult reshapedSource = reshapeConstantSource(
2772 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2773 getResult().getType()))
2774 return reshapedSource;
2775 if (getSourceType() == getType() &&
2776 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2777 return this->getSource();
2778 if (Value slice = foldExtractAfterInsertSlice(*this))
2779 return slice;
2780
2781 return OpFoldResult();
2782}
2783
2784Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2785 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2786 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2787 unsigned rank = rankedTensorType.getRank();
2788 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2789 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2790 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2791 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2792 offsets, sizes, strides);
2793}
2794
2795//===----------------------------------------------------------------------===//
2796// InsertSliceOp
2797//===----------------------------------------------------------------------===//
2798
2799void InsertSliceOp::getAsmResultNames(
2800 function_ref<void(Value, StringRef)> setNameFn) {
2801 setNameFn(getResult(), "inserted_slice");
2802}
2803
2804// Build an InsertSliceOp with mixed static and dynamic entries.
2805void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2806 Value dest, ArrayRef<OpFoldResult> offsets,
2807 ArrayRef<OpFoldResult> sizes,
2808 ArrayRef<OpFoldResult> strides,
2809 ArrayRef<NamedAttribute> attrs) {
2810 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2811 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2812 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2813 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2814 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2815 result.addAttributes(attrs);
2816 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2817 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2818 b.getDenseI64ArrayAttr(staticSizes),
2819 b.getDenseI64ArrayAttr(staticStrides));
2820}
2821
2822/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2823/// Range vector.
2824void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2825 Value dest, ArrayRef<Range> ranges,
2826 ArrayRef<NamedAttribute> attrs) {
2827 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2828 build(b, result, source, dest, offsets, sizes, strides, attrs);
2829}
2830
2831// Build an InsertSliceOp with dynamic entries.
2832void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2833 Value dest, ValueRange offsets, ValueRange sizes,
2834 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2835 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2836 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2837 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2838 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2839 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2840 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2841 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2842}
2843
2844/// Rank-reducing type verification for both InsertSliceOp and
2845/// ParallelInsertSliceOp.
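///
/// For example, inserting a tensor<4xf32> source into a [1, 4] slice of a
/// tensor<8x4xf32> destination verifies: the expected (unreduced) source
/// type is tensor<1x4xf32>, and tensor<4xf32> is a valid rank-reduced
/// version of it.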
2846static SliceVerificationResult verifyInsertSliceOp(
2847 RankedTensorType srcType, RankedTensorType dstType,
2848 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2849 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2850 // insert_slice is the inverse of extract_slice, use the same type
2851 // inference.
2852 RankedTensorType expected =
2853 ExtractSliceOp::inferResultType(dstType, staticSizes);
2854 if (expectedType)
2855 *expectedType = expected;
2856 return isRankReducedType(expected, srcType);
2857}
2858
2859/// Verifier for InsertSliceOp.
2860LogicalResult InsertSliceOp::verify() {
2861 // Verify result type against inferred type.
2862 RankedTensorType expectedType;
2863 SliceVerificationResult result =
2864 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2865 getStaticSizes(), getStaticStrides(), &expectedType);
2866 if (result != SliceVerificationResult::Success)
2867 return produceSliceErrorMsg(result, *this, expectedType);
2868
2869 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2870 // to the destination tensor.
2871 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2872 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2873 getStaticStrides(), /*generateErrorMessage=*/true);
2874 if (!boundsResult.isValid)
2875 return getOperation()->emitError(boundsResult.errorMessage);
2876
2877 return success();
2878}
2879
2880/// If we have two consecutive InsertSliceOps writing to the same slice, we
2881/// can mutate the second InsertSliceOp's destination to the first one's.
2882///
2883/// Example:
2884///
2885/// ```mlir
2886/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2887/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2888/// ```
2889///
2890/// folds into:
2891///
2892/// ```mlir
2893/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2894/// ```
2895///
2896/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2897static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2898 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2899
2900 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2901 if (!prevInsertOp ||
2902 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2903 !prevInsertOp.isSameAs(insertOp, isSame))
2904 return failure();
2905
2906 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2907 return success();
2908}
2909
2910/// Folds round-trip extract/insert slice op pairs.
2911/// Example:
2912/// ```mlir
2913/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2914/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2915/// ```
2916/// can be folded into %val.
2917static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2918 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2919
2920 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2921 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2922 !extractOp.isSameAs(insertOp, isSame))
2923 return nullptr;
2924
2925 return extractOp.getSource();
2926}
2927
2928OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2929 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2930 getSourceType() == getType() &&
2931 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2932 return this->getSource();
2933 if (succeeded(foldInsertAfterInsertSlice(*this)))
2934 return getResult();
2935 if (auto result = foldInsertAfterExtractSlice(*this))
2936 return result;
2937 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2938 return getDest();
2939 return OpFoldResult();
2940}
2941
2942LogicalResult InsertSliceOp::reifyResultShapes(
2943 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2944 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2945 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2946 return success();
2947}
2948
2949namespace {
2950/// Pattern to rewrite an insert_slice op with constant arguments.
2951///
2952/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
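///
/// For example (an illustrative sketch, assuming %c0 and %c4 are constants
/// 0 and 4):
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[%c0] [%c4] [1]
///        : tensor<?xf32> into tensor<8xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
/// %cast = tensor.cast %src : tensor<?xf32> to tensor<4xf32>
/// %0 = tensor.insert_slice %cast into %dest[0] [4] [1]
///        : tensor<4xf32> into tensor<8xf32>
/// ```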
2953template <typename InsertOpTy>
2954class InsertSliceOpConstantArgumentFolder final
2955 : public OpRewritePattern<InsertOpTy> {
2956public:
2957 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2958
2959 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2960 PatternRewriter &rewriter) const override {
2961 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2962 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2963 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2964
2965 // No constant operands were folded, just return.
2966 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2967 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2968 failed(foldDynamicStrideList(mixedStrides)))
2969 return failure();
2970
2971 // Pattern does not apply if the produced op would not verify.
2972 SliceBoundsVerificationResult sliceResult =
2973 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2974 mixedOffsets, mixedSizes, mixedStrides);
2975 if (!sliceResult.isValid)
2976 return failure();
2977
2978 // Create the new op in canonical form.
2979 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2980 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2981 mixedSizes);
2982 Value toInsert = insertSliceOp.getSource();
2983 if (sourceType != insertSliceOp.getSourceType()) {
2984 OpBuilder::InsertionGuard g(rewriter);
2985 // The only difference between InsertSliceOp and ParallelInsertSliceOp
2986 // is that the insertion point is just before the InParallelOp in
2987 // the parallel case.
2988 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
2989 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2990 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
2991 sourceType, toInsert);
2992 }
2993 rewriter.replaceOpWithNewOp<InsertOpTy>(
2994 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2995 mixedSizes, mixedStrides);
2996 return success();
2997 }
2998};
2999
3000/// Fold tensor.cast ops with insert_slice operations. If the source or
3001/// destination tensor is a tensor_cast that removes static type information,
3002/// the cast is folded into the insert_slice operation. E.g.:
3003///
3004/// ```mlir
3005/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3006/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3007/// ```
3008///
3009/// folds into:
3010///
3011/// ```mlir
3012/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3013/// ```
3014///
3015/// Note: When folding a cast on the destination tensor, the result of the
3016/// insert_slice operation is cast back to ensure that the type of the result did
3017/// not change.
3018///
3019/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3020template <typename InsertOpTy>
3021struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3022 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3023
3024 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3025 PatternRewriter &rewriter) const override {
3026 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3027 return matchPattern(operand, matchConstantIndex());
3028 }))
3029 return failure();
3030
3031 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3032 auto castOp = v.getDefiningOp<tensor::CastOp>();
3033 if (!castOp || !canFoldIntoConsumerOp(castOp))
3034 return std::nullopt;
3035 return castOp.getSource();
3036 };
3037 std::optional<Value> sourceCastSource =
3038 getSourceOfCastOp(insertSliceOp.getSource());
3039 std::optional<Value> destCastSource =
3040 getSourceOfCastOp(insertSliceOp.getDest());
3041 if (!sourceCastSource && !destCastSource)
3042 return failure();
3043
3044 auto src =
3045 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3046 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3047 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3048 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3049 if (!srcType || !dstType)
3050 return failure();
3051
3052 // The tensor.cast source could have additional static information not seen
3053 // in the insert slice op static sizes, so we ignore dynamic dims when
3054 // computing the rank reduction mask.
3055 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3056 auto rankReductionMask = computeRankReductionMask(
3057 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3058 if (!rankReductionMask.has_value())
3059 return failure();
3060 // Replace dimensions in the insert slice op with corresponding static dims
3061 // from the cast source type. If the insert slice sizes have static dims
3062 // that are not static in the tensor.cast source (i.e., when the cast op
3063 // casts a dynamic dim to static), the dim should not be replaced, and the
3064 // pattern will fail later in `verifyInsertSliceOp`.
3065 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3066 int64_t rankReducedIdx = 0;
3067 for (auto [idx, size] : enumerate(staticSizes)) {
3068 if (!rankReductionMask.value().contains(idx) &&
3069 !srcType.isDynamicDim(rankReducedIdx)) {
3070 mixedSizes[idx] = getAsIndexOpFoldResult(
3071 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3072 size = srcType.getDimSize(rankReducedIdx++);
3073 }
3074 }
3075
3076 // Pattern does not apply if the produced op would not verify.
3077 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3078 staticSizes, insertSliceOp.getStaticStrides()) !=
3079 SliceVerificationResult::Success)
3080 return failure();
3081 SliceBoundsVerificationResult sliceResult =
3082 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3083 mixedSizes, insertSliceOp.getMixedStrides());
3084 if (!sliceResult.isValid)
3085 return failure();
3086
3087 Operation *replacement =
3088 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3089 insertSliceOp.getMixedOffsets(), mixedSizes,
3090 insertSliceOp.getMixedStrides());
3091
3092 // In the parallel case there is no result and so nothing to cast.
3093 bool isParallelInsert =
3094 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3095 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3096 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3097 insertSliceOp.getDestType(),
3098 replacement->getResult(0));
3099 }
3100 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3101 return success();
3102 }
3103};
3104
3105/// If additional static type information can be deduced from an insert_slice's
3106/// size operands, insert an explicit cast of the op's source operand. This
3107/// enables other canonicalization patterns that are matching for tensor_cast
3108/// ops such as `ForOpTensorCastFolder` in SCF.
3109///
3110/// Example:
3111///
3112/// ```mlir
3113/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3114/// : tensor<?x?xf32> into ...
3115/// ```
3116///
3117/// folds into:
3118///
3119/// ```mlir
3120/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3121/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3122/// : tensor<64x64xf32> into ...
3123/// ```
3124///
3125/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3126template <typename InsertOpTy>
3127struct InsertSliceOpSourceCastInserter final
3128 : public OpRewritePattern<InsertOpTy> {
3129 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3130
3131 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3132 PatternRewriter &rewriter) const override {
3133 RankedTensorType srcType = insertSliceOp.getSourceType();
3134 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3135 return failure();
3136 SmallVector<int64_t> newSrcShape(srcType.getShape());
3137 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3138 if (std::optional<int64_t> constInt =
3139 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3140 // Bail on invalid IR.
3141 if (*constInt < 0)
3142 return failure();
3143 newSrcShape[i] = *constInt;
3144 }
3145 }
3146 if (!hasValidSizesOffsets(newSrcShape))
3147 return failure();
3148
3149 RankedTensorType newSrcType = RankedTensorType::get(
3150 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3151 if (srcType == newSrcType ||
3152 !preservesStaticInformation(srcType, newSrcType) ||
3153 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3154 return failure();
3155
3156 // newSrcType is:
3157 // 1) Different from srcType.
3158 // 2) "More static" than srcType.
3159 // 3) Cast-compatible with srcType.
3160 // Insert the cast.
3161 OpBuilder::InsertionGuard g(rewriter);
3162 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3163 // that the insertion point is just before the InParallelOp in the
3164 // parallel case.
3165 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3166 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3167 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3168 newSrcType, insertSliceOp.getSource());
3169 rewriter.replaceOpWithNewOp<InsertOpTy>(
3170 insertSliceOp, cast, insertSliceOp.getDest(),
3171 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3172 insertSliceOp.getMixedStrides());
3173 return success();
3174 }
3175};
3176} // namespace
3177
3178llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3179 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3180}
3181
3182void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3183 MLIRContext *context) {
3184 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3185 InsertSliceOpCastFolder<InsertSliceOp>,
3186 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3187}
3188
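// `createCanonicalRankReducingInsertSliceOp` (defined below) writes `tensor`
// into `dest` at offset 0, with unit strides, and with the sizes of `dest`.
// A hedged illustration (assuming a rank-reducing insertion of a 1-D tensor
// into a 3-D destination; names are placeholders):
//
// ```mlir
// %r = tensor.insert_slice %t into %d[0, 0, 0] [1, 4, 1] [1, 1, 1]
//     : tensor<4xf32> into tensor<1x4x1xf32>
// ```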
3189Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3190 Location loc,
3191 Value tensor,
3192 Value dest) {
3193 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3194 unsigned rank = rankedTensorType.getRank();
3195 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3196 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3197 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3198 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3199 sizes, strides);
3200}
3201
3202//===----------------------------------------------------------------------===//
3203// PadOp
3204//===----------------------------------------------------------------------===//
3205
3206void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3207 setNameFn(getResult(), "padded");
3208}
3209
3210LogicalResult PadOp::verify() {
3211 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3212 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3213 auto expectedType =
3214 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3215 if (!expectedType) {
3216 return emitError("failed to infer expectedType from sourceType ")
3217 << sourceType << ", specified resultType is " << resultType;
3218 }
3219 if (resultType.getRank() != expectedType.getRank()) {
3220 return emitError("specified type ")
3221 << resultType << " does not match the inferred type "
3222 << expectedType;
3223 }
3224 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3225 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3226 continue;
3227 if (expectedType.isDynamicDim(i))
3228 continue;
3229 return emitError("specified type ")
3230 << resultType << " does not match the inferred type "
3231 << expectedType;
3232 }
3233
3234 return success();
3235}
3236
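// For reference, a well-formed pad region as checked below has one index
// block argument per result dimension and yields a value of the element type,
// e.g. (illustrative only, %cst assumed defined earlier):
//
// ```mlir
// %padded = tensor.pad %t low[1, 1] high[1, 1] {
// ^bb0(%i: index, %j: index):
//   tensor.yield %cst : f32
// } : tensor<2x2xf32> to tensor<4x4xf32>
// ```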
3237LogicalResult PadOp::verifyRegions() {
3238 auto &region = getRegion();
3239 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3240 Block &block = region.front();
3241 if (block.getNumArguments() != rank)
3242 return emitError("expected the block to have ") << rank << " arguments";
3243
3244 // Note: the number and type of yield values are checked in the YieldOp.
3245 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3246 if (!en.value().isIndex())
3247 return emitOpError("expected block argument ")
3248 << (en.index() + 1) << " to be an index";
3249 }
3250
3251 // Ensure that the region yields an element of the right type.
3252 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3253 if (yieldOp.getValue().getType() !=
3254 llvm::cast<ShapedType>(getType()).getElementType())
3255 return emitOpError("expected yield type to match shape element type");
3256
3257 return success();
3258}
3259
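// Shape-inference sketch for `inferResultType` below (illustrative numbers):
// padding a tensor<4x?xf32> with staticLow = [1, 2] and staticHigh = [2, 3]
// yields tensor<7x?xf32>, since 4 + 1 + 2 = 7 and the second dimension stays
// dynamic because the source dimension is dynamic.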
3260RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3261 ArrayRef<int64_t> staticLow,
3262 ArrayRef<int64_t> staticHigh,
3263 ArrayRef<int64_t> resultShape) {
3264 unsigned rank = sourceType.getRank();
3265 if (staticLow.size() != rank)
3266 return RankedTensorType();
3267 if (staticHigh.size() != rank)
3268 return RankedTensorType();
3269 if (!resultShape.empty() && resultShape.size() != rank)
3270 return RankedTensorType();
3271
3272 SmallVector<int64_t, 4> inferredShape;
3273 for (auto i : llvm::seq<unsigned>(0, rank)) {
3274 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3275 staticHigh[i] == ShapedType::kDynamic) {
3276 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3277 : resultShape[i]);
3278 } else {
3279 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3280 assert((resultShape.empty() || size == resultShape[i] ||
3281 resultShape[i] == ShapedType::kDynamic) &&
3282 "mismatch between inferred shape and result shape");
3283 inferredShape.push_back(size);
3284 }
3285 }
3286
3287 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3288}
3289
3290void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3291 Value source, ArrayRef<int64_t> staticLow,
3292 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3293 bool nofold, ArrayRef<NamedAttribute> attrs) {
3294 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3295 if (!resultType)
3296 resultType = inferResultType(sourceType, staticLow, staticHigh);
3297 result.addAttributes(attrs);
3298 build(b, result, resultType, source, low, high,
3299 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3300 nofold ? b.getUnitAttr() : UnitAttr());
3301}
3302
3303void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3304 Value source, ValueRange low, ValueRange high, bool nofold,
3305 ArrayRef<NamedAttribute> attrs) {
3306 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3307 unsigned rank = sourceType.getRank();
3308 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3309 build(b, result, resultType, source, staticVector, staticVector, low, high,
3310 nofold, attrs);
3311}
3312
3313void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3314 Value source, ArrayRef<OpFoldResult> low,
3315 ArrayRef<OpFoldResult> high, bool nofold,
3316 ArrayRef<NamedAttribute> attrs) {
3317 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3318 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3319 SmallVector<int64_t, 4> staticLow, staticHigh;
3320 // staticLow and staticHigh carry the full padding configuration.
3321 // Each entry grows staticLow and staticHigh by one value. If the entry is
3322 // dynamic (i.e., not a constant), dynamicLow and dynamicHigh grow by one
3323 // value as well.
3324 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3325 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3326 if (!resultType) {
3327 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3328 }
3329 assert(llvm::isa<RankedTensorType>(resultType));
3330 result.addAttributes(attrs);
3331 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3332 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3333 nofold ? b.getUnitAttr() : UnitAttr());
3334}
3335
3336void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3337 Value source, ArrayRef<OpFoldResult> low,
3338 ArrayRef<OpFoldResult> high, Value constantPadValue,
3339 bool nofold, ArrayRef<NamedAttribute> attrs) {
3340 build(b, result, resultType, source, low, high, nofold, attrs);
3341
3342 // Add a region and a block to yield the pad value.
3343 Region *region = result.regions[0].get();
3344 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3345 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3346 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3347
3348 // `builder.createBlock` changes the insertion point within the block. Create
3349 // a guard to reset the insertion point of the builder after it is destroyed.
3350 OpBuilder::InsertionGuard guard(b);
3351 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3352 tensor::YieldOp::create(b, result.location, constantPadValue);
3353}
3354
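// Example for `getPaddedDims` below (illustrative): with low = [0, 2] and
// high = [1, 0], both dimensions are reported as padded, since dimension 0
// has non-zero high padding and dimension 1 has non-zero low padding.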
3355llvm::SmallBitVector PadOp::getPaddedDims() {
3356 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3357 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3358 for (const auto &en : enumerate(paddingWidths))
3359 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3360 paddedDims.set(en.index());
3361 };
3362 extractPaddedDims(getMixedLowPad());
3363 extractPaddedDims(getMixedHighPad());
3364 return paddedDims;
3365}
3366
3367namespace {
3368// Folds tensor.pad when padding is static zeros and the attribute
3369// doesn't request otherwise.
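//
// Illustrative example (assuming matching source and result types):
//
// ```mlir
// %p = tensor.pad %t low[0, 0] high[0, 0] {
// ^bb0(%i: index, %j: index):
//   tensor.yield %cst : f32
// } : tensor<8x16xf32> to tensor<8x16xf32>
// ```
//
// is replaced by `tensor.cast %t : tensor<8x16xf32> to tensor<8x16xf32>`,
// which subsequently folds away.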
3370struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3371 using OpRewritePattern<PadOp>::OpRewritePattern;
3372
3373 LogicalResult matchAndRewrite(PadOp padTensorOp,
3374 PatternRewriter &rewriter) const override {
3375 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3376 return failure();
3377 if (padTensorOp.getNofold())
3378 return failure();
3379 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3380 padTensorOp, padTensorOp.getResult().getType(),
3381 padTensorOp.getSource());
3382 return success();
3383 }
3384};
3385
3386// Fold CastOp into PadOp when adding static information.
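// Illustrative example (not from the upstream tests):
//
// ```mlir
// %c = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
// %p = tensor.pad %c low[0, 0] high[2, 3] { ... } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// is rewritten so the pad picks up the static information from %0, producing
// tensor<10x19xf32>, followed by a tensor.cast back to the original
// tensor<?x?xf32> result type.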
3387struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3388 using OpRewritePattern<PadOp>::OpRewritePattern;
3389
3390 LogicalResult matchAndRewrite(PadOp padTensorOp,
3391 PatternRewriter &rewriter) const override {
3392 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3393 if (!tensor::canFoldIntoConsumerOp(castOp))
3394 return failure();
3395
3396 auto newResultType = PadOp::inferResultType(
3397 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3398 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3399 padTensorOp.getResultType().getShape());
3400
3401 if (newResultType == padTensorOp.getResultType()) {
3402 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3403 padTensorOp.getSourceMutable().assign(castOp.getSource());
3404 });
3405 } else {
3406 auto newOp = PadOp::create(
3407 rewriter, padTensorOp->getLoc(), newResultType,
3408 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3409 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3410 padTensorOp.getHigh(), padTensorOp.getNofold(),
3411 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3412 IRMapping mapper;
3413 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3414
3415 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3416 padTensorOp, padTensorOp.getResultType(), newOp);
3417 }
3418 return success();
3419 }
3420};
3421
3422// Fold CastOp using the result of PadOp back into the latter if it adds
3423// static information.
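// Illustrative example: given
//
// ```mlir
// %p = tensor.pad %0 low[0, 0] high[2, 3] { ... } : tensor<?x?xf32> to tensor<?x?xf32>
// %c = tensor.cast %p : tensor<?x?xf32> to tensor<10x19xf32>
// ```
//
// the pad is recreated directly with the static tensor<10x19xf32> result type
// and replaces both the original pad and the cast (the pad result's only use).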
3424struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3425 using OpRewritePattern<PadOp>::OpRewritePattern;
3426
3427 LogicalResult matchAndRewrite(PadOp padTensorOp,
3428 PatternRewriter &rewriter) const override {
3429 if (!padTensorOp.getResult().hasOneUse())
3430 return failure();
3431 auto tensorCastOp =
3432 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3433 if (!tensorCastOp)
3434 return failure();
3435 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3436 tensorCastOp.getDest().getType()))
3437 return failure();
3438
3439 auto replacementOp = PadOp::create(
3440 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3441 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3442 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3443 padTensorOp.getHigh(), padTensorOp.getNofold(),
3444 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3445 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3446
3447 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3448 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3449 return success();
3450 }
3451};
3452
3453/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3454/// different dimensions. The pattern applies if the following preconditions
3455/// hold:
3456/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3457/// 2) the tensor::ExtractSliceOps have only unit-strides,
3458/// 3) the tensor::PadOps perform only high-padding,
3459/// 4) the tensor::PadOps have the same constant padding value,
3460/// 5) the tensor::PadOps do not have common padding dimensions,
3461/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3462/// zero-offset for every dimension.
3463/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3464///    the padded source dimensions.
3466///
3467/// Example:
3468///
3469/// ```mlir
3470/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3471/// : tensor<64x64xf32> to tensor<?x64xf32>
3472/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3473/// } : tensor<?x64xf32> to tensor<8x64xf32>
3474/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3475/// : tensor<8x64xf32> to tensor<8x?xf32>
3476/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3477/// } : tensor<8x?xf32> to tensor<8x4xf32>
3478/// ```
3479///
3480/// folds into:
3481///
3482/// ```mlir
3483/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3484/// : tensor<64x64xf32> to tensor<?x?xf32>
3485/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3486/// } : tensor<?x?xf32> to tensor<8x4xf32>
3487/// ```
3488struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3489 using OpRewritePattern<PadOp>::OpRewritePattern;
3490
3491 LogicalResult matchAndRewrite(PadOp padOp,
3492 PatternRewriter &rewriter) const override {
3493 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3494 if (!innerSliceOp)
3495 return failure();
3496 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3497 if (!outerPadOp || outerPadOp.getNofold())
3498 return failure();
3499 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3500 if (!outerSliceOp)
3501 return failure();
3502
3503 // 1) Fail if the chain is rank-reducing.
3504 int64_t rank = padOp.getSourceType().getRank();
3505 if (outerSliceOp.getSourceType().getRank() != rank) {
3506 return rewriter.notifyMatchFailure(padOp,
3507 "cannot fold rank-reducing chain");
3508 }
3509
3510 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3511 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3512 return rewriter.notifyMatchFailure(
3513 padOp, "cannot fold non-unit stride ExtractSliceOps");
3514 }
3515
3516 // 3) Fail if the tensor::PadOps have non-zero low padding.
3517 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3518 return rewriter.notifyMatchFailure(padOp,
3519 "cannot fold PadOps with low padding");
3520 }
3521
3522 // 4) Fail if the tensor::PadOps padding values do not match.
3523 Attribute innerAttr, outerAttr;
3524 Value innerValue = padOp.getConstantPaddingValue();
3525 Value outerValue = outerPadOp.getConstantPaddingValue();
3526 if (!innerValue || !outerValue ||
3527 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3528 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3529 innerAttr != outerAttr) {
3530 return rewriter.notifyMatchFailure(
3531 padOp, "cannot fold PadOps with different padding values");
3532 }
3533
3534 // 5) Fail if a dimension is padded by both tensor::PadOps.
3535 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3536 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3537 if (innerDims.anyCommon(outerDims)) {
3538 return rewriter.notifyMatchFailure(
3539 padOp, "cannot fold PadOps with common padding dimensions");
3540 }
3541
3542 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3543 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3544 // for every dimension, and use the offset of the other pair. Fail if no
3545 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3546 // exists.
3547 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3548 for (auto en : enumerate(newOffsets)) {
3549 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3550 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3551 if (!innerDims.test(en.index()) &&
3552 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3553 en.value() = outerOffset;
3554 continue;
3555 }
3556 if (!outerDims.test(en.index()) &&
3557 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3558 en.value() = innerOffset;
3559 continue;
3560 }
3561 return rewriter.notifyMatchFailure(
3562 padOp, "cannot find zero-offset and zero-padding pair");
3563 }
3564
3565 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3566 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3567 // outer tensor::PadOp and fail if the size of the inner
3568 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3569 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3570 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3571 for (auto en : enumerate(newSizes)) {
3572 if (!outerDims.test(en.index()))
3573 continue;
3574 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3575 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3576 assert(ShapedType::isStatic(sourceSize) &&
3577 "expected padded dimension to have a static size");
3578 if (getConstantIntValue(sliceSize) != sourceSize) {
3579 return rewriter.notifyMatchFailure(
3580 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3581 "match the size of the outer padding");
3582 }
3583 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3584 }
3585
3586 // Combine the high paddings of the two tensor::PadOps.
3587 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3588 for (auto en : enumerate(newHighPad)) {
3589 if (innerDims.test(en.index()))
3590 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3591 if (outerDims.test(en.index()))
3592 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3593 }
3594
3595 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3596 // the two paddings in one step.
3597 auto newSliceOp = ExtractSliceOp::create(
3598 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3599 newSizes, innerSliceOp.getMixedStrides());
3600 auto newPadOp = PadOp::create(
3601 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3602 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3603 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3604 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3605 newPadOp.getRegion().begin());
3606 rewriter.replaceOp(padOp, newPadOp.getResult());
3607 return success();
3608 }
3609};
3610
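// Promotes pad amounts that are bound to constants from dynamic operands into
// the static low/high attributes and recomputes a more static result type,
// casting back to the original result type. Illustrative example:
//
// ```mlir
// %c2 = arith.constant 2 : index
// %p = tensor.pad %t low[%c2, 0] high[0, %c2] { ... }
//     : tensor<4x4xf32> to tensor<?x?xf32>
// ```
//
// becomes a pad with low[2, 0] high[0, 2] of type tensor<6x6xf32>, followed
// by a tensor.cast to tensor<?x?xf32>.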
3611struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3612 using OpRewritePattern<PadOp>::OpRewritePattern;
3613
3614 LogicalResult matchAndRewrite(PadOp padTensorOp,
3615 PatternRewriter &rewriter) const override {
3616 Value input = padTensorOp.getSource();
3617 if (!llvm::isa<RankedTensorType>(input.getType()))
3618 return failure();
3619 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3620 auto inputRank = inputDims.size();
3621
3622 auto oldResultType =
3623 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3624 if (!oldResultType)
3625 return failure();
3626
3627 auto outputDims = oldResultType.getShape();
3628
3629 // Extract the static info from the high and low operands.
3630 SmallVector<int64_t> constOperandsLow;
3631 SmallVector<Value> newLows;
3632 for (auto operand : padTensorOp.getLow()) {
3633 APSInt intOp;
3634 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3635 constOperandsLow.push_back(ShapedType::kDynamic);
3636 newLows.push_back(operand);
3637 continue;
3638 }
3639 constOperandsLow.push_back(intOp.getExtValue());
3640 }
3641 SmallVector<int64_t> constOperandsHigh;
3642 SmallVector<Value> newHighs;
3643 for (auto operand : padTensorOp.getHigh()) {
3644 APSInt intOp;
3645 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3646 constOperandsHigh.push_back(ShapedType::kDynamic);
3647 newHighs.push_back(operand);
3648 continue;
3649 }
3650 constOperandsHigh.push_back(intOp.getExtValue());
3651 }
3652
3653 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3654 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3655
3656 // Verify the op is well-formed.
3657 if (inputDims.size() != outputDims.size() ||
3658 inputDims.size() != constLow.size() ||
3659 inputDims.size() != constHigh.size())
3660 return failure();
3661
3662 auto lowCount = 0;
3663 auto highCount = 0;
3664 for (size_t i = 0; i < inputRank; i++) {
3665 if (constLow[i] == ShapedType::kDynamic)
3666 constLow[i] = constOperandsLow[lowCount++];
3667 if (constHigh[i] == ShapedType::kDynamic)
3668 constHigh[i] = constOperandsHigh[highCount++];
3669 }
3670
3671 auto staticLow = ArrayRef<int64_t>(constLow);
3672 auto staticHigh = ArrayRef<int64_t>(constHigh);
3673
3674 // Calculate the output sizes with the static information.
3675 SmallVector<int64_t> newOutDims;
3676 for (size_t i = 0; i < inputRank; i++) {
3677 if (outputDims[i] == ShapedType::kDynamic) {
3678 newOutDims.push_back(
3679 (staticLow[i] == ShapedType::kDynamic ||
3680 staticHigh[i] == ShapedType::kDynamic ||
3681 inputDims[i] == ShapedType::kDynamic
3682 ? ShapedType::kDynamic
3683 : inputDims[i] + staticLow[i] + staticHigh[i]));
3684 } else {
3685 newOutDims.push_back(outputDims[i]);
3686 }
3687 }
3688
3689 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3690 llvm::all_of(newOutDims,
3691 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3692 return failure();
3693
3694 // Rewrite the op using the new static type.
3695 auto newResultType = RankedTensorType::get(
3696 newOutDims, padTensorOp.getType().getElementType());
3697 auto newOp = PadOp::create(
3698 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3699 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3700 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3701
3702 IRMapping mapper;
3703 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3704 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3705 newOp);
3706
3707 return success();
3708 }
3709};
3710
3711/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3712///
3713/// Example:
3714///
3715/// ```mlir
3716/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3717/// tensor.yield %val
3718/// } : tensor<1x2xf32> to tensor<2x5xf32>
3719/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3720/// tensor.yield %val
3721/// } : tensor<1x5xf32> to tensor<5x7xf32>
3722/// ```
3723///
3724/// folds into:
3725///
3726/// ```mlir
3727/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3728/// tensor.yield %val
3729/// } : tensor<1x2xf32> to tensor<5x7xf32>
3730/// ```
3731struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3732 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3733
3734 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3735 PatternRewriter &rewriter) const override {
3736 if (padOp.getNofold()) {
3737 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3738 }
3739
3740 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3741 if (!producerPad || producerPad.getNofold()) {
3742 return rewriter.notifyMatchFailure(
3743 padOp, "producer is not a foldable tensor.pad op");
3744 }
3745
3746 // Fail if the tensor::PadOps padding values do not match.
3747 Value consumerPadValue = padOp.getConstantPaddingValue();
3748 Value producerPadValue = producerPad.getConstantPaddingValue();
3749 if (!consumerPadValue || !producerPadValue ||
3750 consumerPadValue != producerPadValue) {
3751 return rewriter.notifyMatchFailure(
3752 padOp,
3753 "cannot fold PadOps with different or non-constant padding values");
3754 }
3755
3756 Location loc = padOp.getLoc();
3757 AffineExpr d0, d1;
3758 bindDims(rewriter.getContext(), d0, d1);
3759
3760 // Combine the low/high paddings of the two tensor::PadOps.
3761 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3762 ArrayRef<OpFoldResult> producerPaddings) {
3763 SmallVector<OpFoldResult> sumPaddings;
3764 for (auto [consumerIndex, producerIndex] :
3765 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3766 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3767 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3768 }
3769 return sumPaddings;
3770 };
3771
3772 SmallVector<OpFoldResult> newHighPad =
3773 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3774 SmallVector<OpFoldResult> newLowPad =
3775 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3776
3777 auto newPadOp = tensor::PadOp::create(
3778 rewriter, padOp.getLoc(), padOp.getResultType(),
3779 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3780 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3781 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3782 newPadOp.getRegion().begin());
3783 rewriter.replaceOp(padOp, newPadOp.getResult());
3784 return success();
3785 }
3786};
3787
3788} // namespace
3789
3790LogicalResult
3791PadOp::reifyResultShapes(OpBuilder &b,
3792 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3793 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3794 SmallVector<OpFoldResult> lp = getMixedLowPad();
3795 SmallVector<OpFoldResult> hp = getMixedHighPad();
3796 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3797 if (!getType().isDynamicDim(i)) {
3798 reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3799 continue;
3800 }
3801 Location loc = getLoc();
3802 Value dim = b.createOrFold<tensor::DimOp>(
3803 loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3804
3805 AffineExpr d0, d1, d2;
3806 bindDims(b.getContext(), d0, d1, d2);
3807 reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3808 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3809 }
3810 return success();
3811}
3812
3813void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3814 MLIRContext *context) {
3815 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3816 FoldOrthogonalPaddings, FoldStaticPadding,
3817 FoldConsecutiveConstantPadding>(context);
3818}
3819
3820/// Return the padding value of the PadOp if it is constant. In this context,
3821/// "constant" means an actual constant or "defined outside of the block".
3822///
3823/// Values are considered constant in three cases:
3824/// - A ConstantLike value.
3825/// - A basic block argument from a different block.
3826/// - A value defined outside of the block.
3827///
3828/// If the padding value is not constant, an empty Value is returned.
3829Value PadOp::getConstantPaddingValue() {
3830 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3831 if (!yieldOp)
3832 return {};
3833 Value padValue = yieldOp.getValue();
3834 // Check if yield value is a constant.
3835 if (matchPattern(padValue, m_Constant()))
3836 return padValue;
3837 // Check if yield value is defined inside the PadOp block.
3838 if (padValue.getParentBlock() == &getRegion().front())
3839 return {};
3840 // Else: Yield value defined outside of the PadOp block.
3841 return padValue;
3842}
3843
3844OpFoldResult PadOp::fold(FoldAdaptor) {
3845 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3846 !getNofold())
3847 return getSource();
3848 return {};
3849}
3850
3851//===----------------------------------------------------------------------===//
3852// ParallelInsertSliceOp
3853//===----------------------------------------------------------------------===//
3854
3855OpResult ParallelInsertSliceOp::getTiedOpResult() {
3856 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3857 for (const auto &it :
3858 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3859 Operation &nextOp = it.value();
3860 if (&nextOp == getOperation())
3861 return parallelCombiningParent.getParentResult(it.index());
3862 }
3863 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3864}
3865
3866// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3867void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3868 Value source, Value dest,
3869 ArrayRef<OpFoldResult> offsets,
3870 ArrayRef<OpFoldResult> sizes,
3871 ArrayRef<OpFoldResult> strides,
3872 ArrayRef<NamedAttribute> attrs) {
3873 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3874 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3875 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3876 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3877 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3878 result.addAttributes(attrs);
3879 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3880 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3881 b.getDenseI64ArrayAttr(staticSizes),
3882 b.getDenseI64ArrayAttr(staticStrides));
3883}
3884
3885/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3886/// packed into a Range vector.
3887void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3888 Value source, Value dest,
3889 ArrayRef<Range> ranges,
3890 ArrayRef<NamedAttribute> attrs) {
3891 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3892 build(b, result, source, dest, offsets, sizes, strides, attrs);
3893}
3894
3895// Build a ParallelInsertSliceOp with dynamic entries.
3896void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3897 Value source, Value dest, ValueRange offsets,
3898 ValueRange sizes, ValueRange strides,
3899 ArrayRef<NamedAttribute> attrs) {
3900 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3901 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3902 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3903 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3904 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3905 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3906 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3907}
3908
3909// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set
3910// to 0, strides set to 1 and inferred result type.
3911void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
3912 Value dest, ArrayRef<OpFoldResult> sizes,
3913 ArrayRef<NamedAttribute> attrs) {
3914 Attribute zeroIdxAttr = b.getIndexAttr(0);
3915 Attribute oneIdxAttr = b.getIndexAttr(1);
3916 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3917 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3918 build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3919}
3920
3921LogicalResult ParallelInsertSliceOp::verify() {
3922 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3923 return this->emitError("expected InParallelOpInterface parent, got: ")
3924 << *(getOperation()->getParentOp());
3925
3926 // Verify result type against inferred type.
3927 RankedTensorType expectedType;
3928 SliceVerificationResult result =
3929 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3930 getStaticSizes(), getStaticStrides(), &expectedType);
3931 if (result != SliceVerificationResult::Success)
3932 return produceSliceErrorMsg(result, *this, expectedType);
3933
3934 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3935 // to the destination tensor.
3936 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3937 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3938 getStaticStrides(), /*generateErrorMessage=*/true);
3939 if (!boundsResult.isValid)
3940 return getOperation()->emitError(boundsResult.errorMessage);
3941
3942 return success();
3943}
3944
3945void ParallelInsertSliceOp::getCanonicalizationPatterns(
3946 RewritePatternSet &results, MLIRContext *context) {
3947 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3948 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3949 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3950}
3951
3952llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3953 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3954}
3955
3956// ParallelCombiningOpInterface implementation.
3957MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3958 return getDestMutable();
3959}
3960
3961Operation *ParallelInsertSliceOp::getIteratingParent() {
3962 // Return the parent InParallelOpInterface's parent.
3963 if (auto combiningOp =
3964 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3965 return combiningOp->getParentOp();
3966 return nullptr;
3967}
3968
3969//===----------------------------------------------------------------------===//
3970// ScatterOp
3971//===----------------------------------------------------------------------===//
3972
3973void ScatterOp::getAsmResultNames(
3974 function_ref<void(Value, StringRef)> setNameFn) {
3975 setNameFn(getResult(), "scatter");
3976}
3977
3978LogicalResult ScatterOp::verify() {
3979 int64_t destRank = getDestType().getRank();
3980 ArrayRef<int64_t> scatterDims = getScatterDims();
3981 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3982 getIndicesType().getShape(), destRank,
3983 "scatter", "dest")))
3984 return failure();
3985
3986 if (!getUnique())
3987 return emitOpError("requires 'unique' attribute to be set");
3988 // TODO: we could also check statically that there are fewer leading index
3989 // tensor dims than the dest dims. If this is not the case, the unique
3990 // attribute cannot be true.
3991
3992 // Use the GatherOp::inferResultType on the `dest` type and verify the
3993 // expected type matches the source type.
3994 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3995 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3996 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3997 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3998 if (getSourceType() != expectedSourceType &&
3999 getSourceType() != expectedRankReducedSourceType) {
4000 return emitOpError("source type mismatch: expected ")
4003 << expectedSourceType << " or its rank-reduced variant "
4004 << expectedRankReducedSourceType << " (got: " << getSourceType()
4005 << ")";
4006 }
4007
4008 return success();
4009}
4010
4011//===----------------------------------------------------------------------===//
4012// SplatOp
4013//===----------------------------------------------------------------------===//
4014
4015void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4016 Type aggregateType, ValueRange dynamicSizes) {
4017 build(builder, result, aggregateType, element, dynamicSizes);
4018}
4019
4020void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4021 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4022 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4023 build(builder, result, aggregateType, element, dynamicSizes);
4024}
4025
4026void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4027 ArrayRef<OpFoldResult> sizes) {
4028 SmallVector<int64_t> staticShape;
4029 SmallVector<Value> dynamicSizes;
4030 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4031 build(builder, result, element, staticShape, dynamicSizes);
4032}
4033
4034void SplatOp::getAsmResultNames(
4035 function_ref<void(Value, StringRef)> setNameFn) {
4036 setNameFn(getResult(), "splat");
4037}
4038
4039LogicalResult SplatOp::verify() {
4040 return verifyDynamicDimensionCount(getOperation(), getType(),
4041 getDynamicSizes());
4042}
4043
4044LogicalResult
4045SplatOp::reifyResultShapes(OpBuilder &builder,
4046 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4047 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4048 unsigned ctr = 0;
4049 for (int64_t i = 0; i < getType().getRank(); ++i) {
4050 if (getType().isDynamicDim(i)) {
4051 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4052 } else {
4053 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4054 }
4055 }
4056 return success();
4057}
4058
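// Constant-folding sketch for `SplatOp::fold` below (illustrative): splatting
// a constant scalar into a statically shaped tensor folds to a splat elements
// attribute, e.g.
//
// ```mlir
// %cst = arith.constant 1.0 : f32
// %0 = tensor.splat %cst : tensor<4xf32>
// ```
//
// folds to the attribute dense<1.000000e+00> : tensor<4xf32>.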
4059OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4060 auto constOperand = adaptor.getInput();
4061 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4062 return {};
4063
4064 // Do not fold if the splat is not statically shaped
4065 if (!getType().hasStaticShape())
4066 return {};
4067
4068 // SplatElementsAttr::get treats single value for second arg as being a
4069 // splat.
4070 return SplatElementsAttr::get(getType(), {constOperand});
4071}
4072
4073//===----------------------------------------------------------------------===//
4074// Common Canonicalizers and Folders.
4075//===----------------------------------------------------------------------===//
4076static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4077 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4078 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4079 // might need special handling of attached regions.
4080 if (isa<InsertSliceOp>(op.getOperation()) ||
4081 isa<LoopLikeOpInterface>(op.getOperation()))
4082 return false;
4083
4084 return hasFoldableTensorCastOperand(op);
4085}
4086
4087/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4088/// the `tensor.cast` has source that is more static than the consuming op.
4089///
4090/// Example:
4091/// ```mlir
4092/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4093/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4094/// ```
4095///
4096/// folds into:
4097///
4098/// ```mlir
4099/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4100/// ```
4101/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4102/// can add the pattern to their canonicalizers.
4103struct FoldTensorCastProducerOp
4104 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4105 using OpInterfaceRewritePattern<
4106 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4107
4108 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4109 PatternRewriter &rewriter) const override {
4110
4111 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4112 // for that instead.
4113 if (!foldTensorCastPrecondition(op) ||
4114 isa<linalg::RelayoutOpInterface>(*op))
4115 return failure();
4116
4117 SmallVector<Type> newResultTypes(op->getResultTypes());
4118 SmallVector<Value> newOperands =
4119 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4120
4121 // Clone op
4122 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4123
4124 SmallVector<Value, 4> replacements;
4125 replacements.reserve(newOp->getNumResults());
4126 for (auto [oldResult, newResult] :
4127 llvm::zip(op->getResults(), newOp->getResults())) {
4128 if (newResult.getType() != oldResult.getType()) {
4129 replacements.push_back(tensor::CastOp::create(
4130 rewriter, op->getLoc(), oldResult.getType(), newResult));
4131 } else {
4132 replacements.push_back(newResult);
4133 }
4134 }
4135 rewriter.replaceOp(op, replacements);
4136
4137 return success();
4138 }
4139};
4140
4141//===----------------------------------------------------------------------===//
4142// TensorDialect
4143//===----------------------------------------------------------------------===//
4144
4145void TensorDialect::getCanonicalizationPatterns(
4146 RewritePatternSet &results) const {
4147 results.add<FoldTensorCastProducerOp>(getContext());
4148}
4149
4150//===----------------------------------------------------------------------===//
4151// TableGen'd op method definitions
4152//===----------------------------------------------------------------------===//
4153
4154#define GET_OP_CLASSES
4155#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
Attributes are known-constant values of operations.
Definition Attributes.h:25
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:378
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
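A hedged sketch of how this check is typically used inside a rewrite pattern; rewriter, castOp, and the consuming op are assumptions of the example, not taken from this file:
  // Fold only when the cast merely erases static information, so the consumer
  // can use the more-static source of the cast directly.
  if (!tensor::canFoldIntoConsumerOp(castOp))
    return rewriter.notifyMatchFailure(castOp, "cast does not erase static info");
  // castOp.getSource() may now replace the cast result on the consuming op.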
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:58
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:76
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:67
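A small usage sketch, assuming b, loc, and a ranked tensor value src are in scope (the helper name makeEmptyLike is hypothetical): static dimensions come back as index attributes, dynamic ones as freshly created tensor.dim values, and the result can be handed straight to tensor.empty:
  // Hypothetical helper: create a tensor.empty with the same shape as `src`.
  Value makeEmptyLike(OpBuilder &b, Location loc, Value src) {
    auto srcType = llvm::cast<RankedTensorType>(src.getType());
    SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, src);
    return tensor::EmptyOp::create(b, loc, sizes, srcType.getElementType());
  }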
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
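A brief sketch of the typical call site, assuming b, loc, and an Operation* op with tensor results are in scope; destination-style producers hand back their existing init operands, anything else gets a new tensor.empty:
  SmallVector<Value> destinations;
  if (failed(tensor::getOrCreateDestinations(b, loc, op, destinations)))
    return failure();
  // destinations[i] can now serve as the init/outs operand of a replacement op.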
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
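A minimal sketch of the common consumption pattern for an OpFoldResult ofr (with b and loc assumed in scope): branch on the constant case first and only materialize an index constant when the value is genuinely dynamic:
  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    // Statically known: use *cst directly, no IR is created.
  } else {
    // Unknown at compile time: materialize a Value of index type.
    Value v = getValueOrCreateConstantIndexOp(b, loc, ofr);
    (void)v;
  }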
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
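A tiny illustration, using made-up shapes rather than anything from this file, of how a violation is reported through the isValid/errorMessage fields listed further down:
  SliceBoundsVerificationResult res = verifyInBoundsSlice(
      /*shape=*/{10, 20}, /*staticOffsets=*/{0, 5}, /*staticSizes=*/{10, 16},
      /*staticStrides=*/{1, 1}, /*generateErrorMessage=*/true);
  // res.isValid is false: along dim 1, 5 + (16 - 1) * 1 = 20 is out of bounds.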
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
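A worked example under the assumption of row-major strides for a 2x2x3 shape:
  // Strides {6, 3, 1} describe a row-major 2x2x3 layout.
  SmallVector<int64_t> idx = delinearize(/*linearIndex=*/7, /*strides=*/{6, 3, 1});
  // idx == {1, 0, 1} because 7 = 1*6 + 0*3 + 1*1.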
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:90
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
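An illustrative case (shapes invented for the example): the returned mask is the set of unit dimensions that must be dropped to get from the original shape to the reduced one, or std::nullopt when no such assignment exists:
  // originalShape = {1, 4, 1, 8}, reducedShape = {4, 8}
  //   -> mask contains positions {0, 2}, the dropped unit dimensions.
  // originalShape = {1, 4, 1, 8}, reducedShape = {4, 4}
  //   -> std::nullopt, since {4, 4} cannot be obtained by dropping 1s.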
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
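A short sketch, assuming a SmallVector<OpFoldResult> mixed is in scope, of round-tripping between the mixed form and the split static/dynamic form that shaped ops store; attribute entries become int64_t constants while Value entries are replaced by a dynamic-size sentinel in the static array:
  auto [staticVals, dynamicVals] = decomposeMixedValues(mixed);
  // staticVals : SmallVector<int64_t>, dynamicVals : SmallVector<Value>.
  // The inverse direction is getMixedValues(staticVals, dynamicVals, ctx).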
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.