TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/TypeUtilities.h"
26 #include "llvm/ADT/DenseSet.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallBitVector.h"
29 #include "llvm/ADT/StringRef.h"
30 #include <algorithm>
31 #include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::tensor;
35 
36 /// Materialize a single constant operation from a given attribute value with
37 /// the desired resultant type.
38 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
39  Attribute value, Type type,
40  Location loc) {
41  if (arith::ConstantOp::isBuildableWith(value, type))
42  return builder.create<arith::ConstantOp>(loc, value, type);
43  if (complex::ConstantOp::isBuildableWith(value, type))
44  return builder.create<complex::ConstantOp>(loc, type,
45  value.cast<ArrayAttr>());
46  return nullptr;
47 }
48 
49 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
50  Location loc, Value value) {
51  auto tensorType = value.getType().cast<RankedTensorType>();
52  SmallVector<OpFoldResult> result;
53  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
54  if (tensorType.isDynamicDim(i)) {
55  Value size = builder.create<tensor::DimOp>(loc, value, i);
56  result.push_back(size);
57  } else {
58  result.push_back(builder.getIndexAttr(tensorType.getDimSize(i)));
59  }
60  }
61  return result;
62 }
63 
64 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
65  OpResult opResult) {
66  auto tensorType = opResult.getType().dyn_cast<TensorType>();
67  assert(tensorType && "expected tensor type");
68 
69  // If the op has a destination, it implements DestinationStyleOpInterface and
70  // we can query the destination operand from that interface.
71  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
72  if (destOp)
73  return destOp.getTiedOpOperand(opResult)->get();
74 
75  // Otherwise, create a new destination tensor with the same shape.
76  OpBuilder::InsertionGuard g(b);
77  b.setInsertionPoint(opResult.getDefiningOp());
78 
79  // Compute sizes.
80  SmallVector<OpFoldResult> mixedSizes;
81  if (!tensorType.hasStaticShape()) {
82  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
83  ReifiedRankedShapedTypeDims reifiedShapes;
84  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
85  return failure();
86  mixedSizes = reifiedShapes[opResult.getResultNumber()];
87  } else {
88  // Static shape: Take static sizes directly.
89  for (int64_t sz : tensorType.getShape())
90  mixedSizes.push_back(b.getIndexAttr(sz));
91  }
92 
93  // Create empty tensor.
94  Value emptyTensor =
95  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
96  return emptyTensor;
97 }
98 
99 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
100  Operation *op,
101  SmallVector<Value> &result) {
102  for (OpResult opResult : op->getResults()) {
103  if (opResult.getType().isa<TensorType>()) {
104  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
105  if (failed(destination))
106  return failure();
107  result.push_back(*destination);
108  }
109  }
110  return success();
111 }
112 
113 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
114 /// rank-extending tensor.insert_slice op.
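/// For instance (illustrative values): with mixedSizes = [1, %sz, 1, 8] and a
/// reducedShape of [?, 8], the first and third sizes are static 1s that do not
/// appear in the reduced shape, so the returned bit vector has bits 0 and 2 set.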
115 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
116  ArrayRef<OpFoldResult> mixedSizes) {
117  llvm::SmallBitVector droppedDims(mixedSizes.size());
118  int64_t shapePos = 0;
119 
120  for (const auto &size : enumerate(mixedSizes)) {
121  // Rank-reduced dims must have a static unit dimension.
122  bool isStaticUnitSize =
123  size.value().is<Attribute>() &&
124  size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;
125 
126  if (shapePos == static_cast<int64_t>(reducedShape.size())) {
127  // There are no more dims in the reduced shape. All remaining sizes must
128  // be rank-reduced dims.
129  assert(isStaticUnitSize && "expected unit dim");
130  droppedDims.set(size.index());
131  continue;
132  }
133 
134  // Dim is preserved if the size is not a static 1.
135  if (!isStaticUnitSize) {
136  ++shapePos;
137  continue;
138  }
139 
140  // Dim is preserved if the reduced shape dim is also 1.
141  if (reducedShape[shapePos] == 1) {
142  ++shapePos;
143  continue;
144  }
145 
146  // Otherwise: Dim is dropped.
147  droppedDims.set(size.index());
148  }
149 
150  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
151  "dimension mismatch");
152  return droppedDims;
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // CastOp
157 //===----------------------------------------------------------------------===//
158 
159 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
160  setNameFn(getResult(), "cast");
161 }
162 
163 /// Returns true if `target` is a ranked tensor type that preserves static
164 /// information available in the `source` ranked tensor type.
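/// A couple of illustrative cases (not exhaustive):
///   preservesStaticInformation(tensor<?x16xf32>, tensor<8x16xf32>) == true
///   preservesStaticInformation(tensor<8x16xf32>, tensor<8x?xf32>)  == false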
165 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
166  auto sourceType = source.dyn_cast<RankedTensorType>();
167  auto targetType = target.dyn_cast<RankedTensorType>();
168 
169  // Requires RankedTensorType.
170  if (!sourceType || !targetType)
171  return false;
172 
173  // Requires same elemental type.
174  if (sourceType.getElementType() != targetType.getElementType())
175  return false;
176 
177  // Requires same rank.
178  if (sourceType.getRank() != targetType.getRank())
179  return false;
180 
181  // If cast is towards more static sizes along any dimension, don't fold.
182  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
183  if (!ShapedType::isDynamic(std::get<0>(t)) &&
184  ShapedType::isDynamic(std::get<1>(t)))
185  return false;
186  }
187 
188  return true;
189 }
190 
191 /// Determines whether tensor::CastOp casts to a more dynamic version of the
192 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
193 /// implement canonicalization patterns for ops in different dialects that may
194 /// consume the results of tensor.cast operations. Such foldable tensor.cast
195 /// operations are typically inserted as `slice` ops and are canonicalized
196 /// to preserve the type compatibility of their uses.
197 ///
198 /// Returns true when all conditions are met:
199 /// 1. source and result are ranked tensors with same element type and rank.
200 /// 2. the source type has more static information than the result type.
201 ///
202 /// Example:
203 /// ```mlir
204 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
205 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
206 /// ```
207 ///
208 /// folds into:
209 ///
210 /// ```mlir
211 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
212 /// ```
213 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
214  if (!castOp)
215  return false;
216 
217  // Can fold if the source of cast has at least as much static information as
218  // its results.
219  return preservesStaticInformation(castOp.getType(),
220  castOp.getSource().getType());
221 }
222 
223 /// Determines whether the tensor::CastOp casts to a more static version of the
224 /// source tensor. This is useful to fold into a producing op and implement
225 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
226 /// being from different dialects. Returns true when all conditions are met:
227 /// 1. source and result are ranked tensors with same element type and rank.
228 /// 2. the result type has more static information than the source.
229 ///
230 /// Example:
231 /// ```mlir
232 /// %1 = producer ... : tensor<?x?xf32>
233 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
234 /// ```
235 ///
236 /// can be canonicalized to:
237 ///
238 /// ```mlir
239 /// %2 = producer ... : tensor<8x16xf32>
240 /// ```
241 /// Not all ops might be canonicalizable this way, but for those that can be,
242 /// this method provides a check that it is worth doing the canonicalization.
243 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
244  if (!castOp)
245  return false;
246  return preservesStaticInformation(castOp.getSource().getType(),
247  castOp.getType());
248 }
249 
250 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
251 /// that can be folded.
252 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
253  bool folded = false;
254  for (OpOperand &operand : op->getOpOperands()) {
255  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
256  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
257  operand.set(castOp.getOperand());
258  folded = true;
259  }
260  }
261  return success(folded);
262 }
263 
264 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
265  if (inputs.size() != 1 || outputs.size() != 1)
266  return false;
267  Type a = inputs.front(), b = outputs.front();
268  auto aT = a.dyn_cast<TensorType>();
269  auto bT = b.dyn_cast<TensorType>();
270  if (!aT || !bT)
271  return false;
272 
273  if (aT.getElementType() != bT.getElementType())
274  return false;
275 
276  return succeeded(verifyCompatibleShape(aT, bT));
277 }
278 
279 /// Compute a TensorType that has the joined shape knowledge of the two
280 /// given TensorTypes. The element types need to match.
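/// For example (illustrative): joining tensor<?x16xf32> with tensor<8x?xf32>
/// yields tensor<8x16xf32>; joining tensor<8xf32> with tensor<16xf32> has no
/// join and a null type is returned.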
281 static TensorType joinShapes(TensorType one, TensorType two) {
282  assert(one.getElementType() == two.getElementType());
283 
284  if (!one.hasRank())
285  return two;
286  if (!two.hasRank())
287  return one;
288 
289  int64_t rank = one.getRank();
290  if (rank != two.getRank())
291  return {};
292 
293  SmallVector<int64_t, 4> join;
294  join.reserve(rank);
295  for (int64_t i = 0; i < rank; ++i) {
296  if (one.isDynamicDim(i)) {
297  join.push_back(two.getDimSize(i));
298  continue;
299  }
300  if (two.isDynamicDim(i)) {
301  join.push_back(one.getDimSize(i));
302  continue;
303  }
304  if (one.getDimSize(i) != two.getDimSize(i))
305  return {};
306  join.push_back(one.getDimSize(i));
307  }
308  return RankedTensorType::get(join, one.getElementType());
309 }
310 
311 namespace {
312 
313 /// Replaces chains of two tensor.cast operations by a single tensor.cast
314 /// operation if doing so does not remove runtime constraints.
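/// For instance (illustrative):
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// %2 = tensor.cast %1 : tensor<4x?xf32> to tensor<4x8xf32>
/// ```
///
/// can be rewritten as a single `tensor.cast` from tensor<?x?xf32> to
/// tensor<4x8xf32>, since the final type already implies every constraint
/// checked by the intermediate cast.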
315 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
317 
318  LogicalResult matchAndRewrite(CastOp tensorCast,
319  PatternRewriter &rewriter) const final {
320  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
321 
322  if (!tensorCastOperand)
323  return failure();
324 
325  auto sourceType =
326  tensorCastOperand.getOperand().getType().cast<TensorType>();
327  auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
328  auto resultType = tensorCast.getType().cast<TensorType>();
329 
330  // We can remove the intermediate cast if joining all three produces the
331  // same result as just joining the source and result shapes.
332  auto firstJoin =
333  joinShapes(joinShapes(sourceType, intermediateType), resultType);
334 
335  // The join might not exist if the cast sequence would fail at runtime.
336  if (!firstJoin)
337  return failure();
338 
339  // The newJoin always exists if the above join exists; it might just contain
340  // less information. If so, we cannot drop the intermediate cast, as doing
341  // so would remove runtime checks.
342  auto newJoin = joinShapes(sourceType, resultType);
343  if (firstJoin != newJoin)
344  return failure();
345 
346  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
347  tensorCastOperand.getOperand());
348  return success();
349  }
350 };
351 
352 /// Fold tensor.cast into tensor.extract_slice producer.
353 /// Example:
354 /// ```
355 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
356 /// tensor<128x512xf32> to tensor<?x512xf32>
357 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
358 /// ```
359 /// ->
360 /// ```
361 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
362 /// tensor<128x512xf32> to tensor<16x512xf32>
363 /// ```
364 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
366 
367  LogicalResult matchAndRewrite(CastOp tensorCast,
368  PatternRewriter &rewriter) const final {
369  auto extractOperand =
370  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
371 
372  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
373  tensorCast.getType().getShape() == tensorCast.getSource()
374  .getType()
375  .cast<RankedTensorType>()
376  .getShape())
377  return failure();
378 
379  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
380  auto dimMask = computeRankReductionMask(
381  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
382  size_t dimIndex = 0;
383  for (size_t i = 0, e = sizes.size(); i < e; i++) {
384  if (dimMask && dimMask->count(i))
385  continue;
386  int64_t dim = tensorCast.getType().getShape()[dimIndex++];
387  if (ShapedType::isDynamic(dim))
388  continue;
389  sizes[i] = rewriter.getIndexAttr(dim);
390  }
391 
392  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
393  tensorCast, tensorCast.getType().cast<RankedTensorType>(),
394  extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
395  extractOperand.getMixedStrides());
396  return success();
397  }
398 };
399 
400 } // namespace
401 
402 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
403  MLIRContext *context) {
404  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // DimOp
409 //===----------------------------------------------------------------------===//
410 
411 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
412  setNameFn(getResult(), "dim");
413 }
414 
415 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
416  int64_t index) {
417  auto loc = result.location;
418  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
419  build(builder, result, source, indexValue);
420 }
421 
422 std::optional<int64_t> DimOp::getConstantIndex() {
423  return getConstantIntValue(getIndex());
424 }
425 
426 Speculation::Speculatability DimOp::getSpeculatability() {
427  auto constantIndex = getConstantIndex();
428  if (!constantIndex)
429  return Speculation::NotSpeculatable;
430 
431  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
432  if (!rankedSourceType)
433  return Speculation::NotSpeculatable;
434 
435  // The verifier rejects operations that violate this assertion.
436  assert(constantIndex < rankedSourceType.getRank());
437  return Speculation::Speculatable;
438 }
439 
440 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
441  // All forms of folding require a known index.
442  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
443  if (!index)
444  return {};
445 
446  // Folding for unranked types (UnrankedTensorType) is not supported.
447  auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
448  if (!tensorType)
449  return {};
450 
451  // Out of bound indices produce undefined behavior but are still valid IR.
452  // Don't choke on them.
453  int64_t indexVal = index.getInt();
454  if (indexVal < 0 || indexVal >= tensorType.getRank())
455  return {};
456 
457  // Fold if the shape extent along the given index is known.
458  if (!tensorType.isDynamicDim(index.getInt())) {
459  Builder builder(getContext());
460  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
461  }
462 
463  Operation *definingOp = getSource().getDefiningOp();
464 
465  // Fold dim to the operand of tensor.generate.
466  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
467  auto resultType =
468  fromElements.getResult().getType().cast<RankedTensorType>();
469  // The case where the type encodes the size of the dimension is handled
470  // above.
471  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
472 
473  // Find the operand of the fromElements that corresponds to this index.
474  auto dynExtents = fromElements.getDynamicExtents().begin();
475  for (auto dim : resultType.getShape().take_front(index.getInt()))
476  if (ShapedType::isDynamic(dim))
477  dynExtents++;
478 
479  return Value{*dynExtents};
480  }
481 
482  // The size at the given index is now known to be a dynamic size.
483  unsigned unsignedIndex = index.getValue().getZExtValue();
484 
485  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
486  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
487  // `resolve-shaped-type-result-dims` pass.
488  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
489  sliceOp.isDynamicSize(unsignedIndex)) {
490  return {sliceOp.getDynamicSize(unsignedIndex)};
491  }
492  }
493 
494  // dim(cast) -> dim
495  if (succeeded(foldTensorCast(*this)))
496  return getResult();
497 
498  return {};
499 }
500 
501 namespace {
502 /// Fold dim of a cast into the dim of the source of the tensor cast.
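/// E.g. (illustrative): `tensor.dim %casted, %c0`, where `%casted = tensor.cast
/// %src : tensor<4x?xf32> to tensor<?x?xf32>`, becomes `tensor.dim %src, %c0`,
/// since a cast never changes the runtime shape.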
503 struct DimOfCastOp : public OpRewritePattern<DimOp> {
505 
506  LogicalResult matchAndRewrite(DimOp dimOp,
507  PatternRewriter &rewriter) const override {
508  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
509  if (!castOp)
510  return failure();
511  Value newSource = castOp.getOperand();
512  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
513  return success();
514  }
515 };
516 } // namespace
517 
518 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
519  MLIRContext *context) {
520  results.add<DimOfCastOp>(context);
521 }
522 
523 //===----------------------------------------------------------------------===//
524 // EmptyOp
525 //===----------------------------------------------------------------------===//
526 
527 void EmptyOp::build(OpBuilder &builder, OperationState &result,
528  ArrayRef<int64_t> staticShape, Type elementType,
529  Attribute encoding) {
530  assert(all_of(staticShape,
531  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
532  "expected only static sizes");
533  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
534 }
535 
536 void EmptyOp::build(OpBuilder &builder, OperationState &result,
537  ArrayRef<int64_t> staticShape, Type elementType,
538  ValueRange dynamicSizes, Attribute encoding) {
539  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
540  build(builder, result, tensorType, dynamicSizes);
541 }
542 
543 void EmptyOp::build(OpBuilder &builder, OperationState &result,
544  ArrayRef<OpFoldResult> sizes, Type elementType,
545  Attribute encoding) {
546  SmallVector<int64_t> staticShape;
547  SmallVector<Value> dynamicSizes;
548  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
549  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
550 }
551 
552 LogicalResult EmptyOp::verify() {
553  if (getType().getNumDynamicDims() !=
554  static_cast<int64_t>(getDynamicSizes().size()))
555  return emitOpError("incorrect number of dynamic sizes, has ")
556  << getDynamicSizes().size() << ", expected "
557  << getType().getNumDynamicDims();
558  return success();
559 }
560 
561 LogicalResult
562 EmptyOp::reifyResultShapes(OpBuilder &builder,
563  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
564  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
565  unsigned ctr = 0;
566  for (int64_t i = 0; i < getType().getRank(); ++i) {
567  if (getType().isDynamicDim(i)) {
568  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
569  } else {
570  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
571  }
572  }
573  return success();
574 }
575 
576 Value EmptyOp::getDynamicSize(unsigned idx) {
577  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
578  unsigned ctr = 0;
579  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
580  if (getType().isDynamicDim(i))
581  ++ctr;
582  return getDynamicSizes()[ctr];
583 }
584 
585 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
586  SmallVector<OpFoldResult> result;
587  unsigned ctr = 0;
588  OpBuilder b(getContext());
589  for (int64_t i = 0; i < getType().getRank(); ++i) {
590  if (getType().isDynamicDim(i)) {
591  result.push_back(getDynamicSizes()[ctr++]);
592  } else {
593  result.push_back(b.getIndexAttr(getType().getShape()[i]));
594  }
595  }
596  return result;
597 }
598 
599 namespace {
600 /// Change the type of the result of a `tensor.empty` by making the result
601 /// type statically sized along dimensions that in the original operation were
602 /// defined as dynamic, but the size was defined using a `constant` op. For
603 /// example
604 ///
605 /// %c5 = arith.constant 5: index
606 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
607 ///
608 /// to
609 ///
610 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
611 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
613 
614  LogicalResult matchAndRewrite(EmptyOp op,
615  PatternRewriter &rewriter) const override {
616  SmallVector<int64_t> staticShape(op.getType().getShape().begin(),
617  op.getType().getShape().end());
618  SmallVector<Value> dynamicSizes;
619 
620  // Compute new static and dynamic sizes.
621  unsigned ctr = 0;
622  bool changedType = false;
623  for (int64_t i = 0; i < op.getType().getRank(); ++i) {
624  if (op.getType().isDynamicDim(i)) {
625  Value dynamicSize = op.getDynamicSizes()[ctr++];
626  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
627  if (cst.has_value()) {
628  staticShape[i] = *cst;
629  changedType = true;
630  } else {
631  dynamicSizes.push_back(dynamicSize);
632  }
633  }
634  }
635 
636  // Stop here if no dynamic size was promoted to static.
637  if (!changedType)
638  return failure();
639 
640  auto tensorType = RankedTensorType::get(
641  staticShape, op.getType().getElementType(), op.getType().getEncoding());
642  auto newOp =
643  rewriter.create<EmptyOp>(op.getLoc(), tensorType, dynamicSizes);
644  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
645  return success();
646  }
647 };
648 
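/// Fold `tensor.dim` of a `tensor.empty` with a dynamic dimension to the
/// corresponding dynamic size operand. Illustrative example:
///
/// ```mlir
/// %0 = tensor.empty(%sz) : tensor<?x4xf32>
/// %d = tensor.dim %0, %c0 : tensor<?x4xf32>
/// ```
///
/// canonicalizes so that `%d` is replaced directly by `%sz`.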
649 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
651 
652  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
653  PatternRewriter &rewriter) const override {
654  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
655  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
656  if (!emptyTensorOp || !maybeConstantIndex)
657  return failure();
658  if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
659  return failure();
660  rewriter.replaceOp(dimOp,
661  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
662  return success();
663  }
664 };
665 
666 /// Canonicalize
667 ///
668 /// ```mlir
669 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
670 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
671 /// ```
672 ///
673 /// into
674 ///
675 /// ```mlir
676 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
677 /// ```
678 ///
679 /// This assumes the input program is correct in terms of its shape. So it is
680 /// safe to assume that `%d0` is in fact 4.
681 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
683 
684  LogicalResult matchAndRewrite(CastOp castOp,
685  PatternRewriter &rewriter) const override {
686  if (!canFoldIntoProducerOp(castOp))
687  return failure();
688  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
689  if (!producer)
690  return failure();
691 
692  auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
693  ArrayRef<int64_t> resultShape = resultType.getShape();
694  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
695  SmallVector<OpFoldResult> newMixedSizes;
696  newMixedSizes.reserve(currMixedSizes.size());
697  assert(resultShape.size() == currMixedSizes.size() &&
698  "mismatch in result shape and sizes of empty op");
699  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
700  int64_t newDim = std::get<0>(it);
701  OpFoldResult currDim = std::get<1>(it);
702  // Case 1: The empty tensor dim is static. Check that the tensor cast
703  // result dim matches.
704  if (auto attr = currDim.dyn_cast<Attribute>()) {
705  if (ShapedType::isDynamic(newDim) ||
706  newDim != attr.cast<IntegerAttr>().getInt()) {
707  // Something is off, the cast result shape cannot be more dynamic
708  // than the empty tensor result shape (enforced by
709  // `canFoldIntoProducer`). Abort for now.
710  return rewriter.notifyMatchFailure(
711  producer, "mismatch in static value of shape of empty tensor "
712  "result and cast result");
713  }
714  newMixedSizes.push_back(attr);
715  continue;
716  }
717 
718  // Case 2 : The tensor cast shape is static, but empty tensor result
719  // shape is dynamic.
720  if (!ShapedType::isDynamic(newDim)) {
721  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
722  continue;
723  }
724 
725  // Case 3 : The tensor cast shape is dynamic and empty tensor result
726  // shape is dynamic. Use the dynamic value from the empty tensor op.
727  newMixedSizes.push_back(currDim);
728  }
729 
730  // TODO: Do not drop tensor encoding.
731  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
732  resultType.getElementType());
733  return success();
734  }
735 };
736 
737 } // namespace
738 
739 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
740  MLIRContext *context) {
741  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
742  ReplaceEmptyTensorStaticShapeDims>(context);
743 }
744 
745 //===----------------------------------------------------------------------===//
746 // ExtractOp
747 //===----------------------------------------------------------------------===//
748 
749 namespace {
750 
751 /// Canonicalizes the pattern of the form
752 ///
753 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
754 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
755 ///
756 /// to
757 ///
758 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
759 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
761 
762  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
763  PatternRewriter &rewriter) const final {
764  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
765  if (!tensorCast)
766  return failure();
767  if (!tensorCast.getSource().getType().isa<RankedTensorType>())
768  return failure();
769  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
770  extract, tensorCast.getSource(), extract.getIndices());
771  return success();
772  }
773 };
774 
775 } // namespace
776 
777 void ExtractOp::getAsmResultNames(
778  function_ref<void(Value, StringRef)> setNameFn) {
779  setNameFn(getResult(), "extracted");
780 }
781 
782 LogicalResult ExtractOp::verify() {
783  // Verify the # indices match if we have a ranked type.
784  auto tensorType = getTensor().getType().cast<RankedTensorType>();
785  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
786  return emitOpError("incorrect number of indices for extract_element");
787  return success();
788 }
789 
790 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
791  // If this is a splat elements attribute, simply return the value. All of
792  // the elements of a splat attribute are the same.
793  if (Attribute tensor = adaptor.getTensor())
794  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
795  return splatTensor.getSplatValue<Attribute>();
796 
797  // Collect the constant indices into the tensor.
798  SmallVector<uint64_t, 8> indices;
799  for (Attribute indice : adaptor.getIndices()) {
800  if (!indice || !indice.isa<IntegerAttr>())
801  return {};
802  indices.push_back(indice.cast<IntegerAttr>().getInt());
803  }
804 
805  // Fold extract(from_elements(...)).
806  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
807  auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
808  auto rank = tensorType.getRank();
809  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
810  "rank mismatch");
811  int flatIndex = 0;
812  int stride = 1;
813  for (int i = rank - 1; i >= 0; --i) {
814  if (i < rank - 1)
815  stride *= tensorType.getDimSize(i);
816  flatIndex += indices[i] * stride;
817  }
818  // Prevent out of bounds accesses. This can happen in invalid code that
819  // will never execute.
820  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
821  flatIndex < 0)
822  return {};
823  return fromElementsOp.getElements()[flatIndex];
824  }
825 
826  // If this is an elements attribute, query the value at the given indices.
827  if (Attribute tensor = adaptor.getTensor()) {
828  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
829  if (elementsAttr && elementsAttr.isValidIndex(indices))
830  return elementsAttr.getValues<Attribute>()[indices];
831  }
832 
833  return {};
834 }
835 
836 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
837  MLIRContext *context) {
838  results.add<ExtractFromTensorCast>(context);
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // FromElementsOp
843 //===----------------------------------------------------------------------===//
844 
845 void FromElementsOp::getAsmResultNames(
846  function_ref<void(Value, StringRef)> setNameFn) {
847  setNameFn(getResult(), "from_elements");
848 }
849 
850 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
851  Type resultType, ValueRange elements) {
852  result.addOperands(elements);
853  result.addTypes(resultType);
854 }
855 
856 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
857  ValueRange elements) {
858  assert(!elements.empty() && "expected at least one element");
859  Type resultType = RankedTensorType::get(
860  {static_cast<int64_t>(elements.size())}, elements.front().getType());
861  build(builder, result, resultType, elements);
862 }
863 
864 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
865  if (!llvm::is_contained(adaptor.getElements(), nullptr))
866  return DenseElementsAttr::get(getType(), adaptor.getElements());
867  return {};
868 }
869 
870 namespace {
871 
872 // Pushes the index_casts that occur before extractions to after the extract.
873 // This minimizes type conversion in some cases and enables the extract
874 // canonicalizer. This changes:
875 //
876 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
877 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
878 //
879 // to the following:
880 //
881 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
882 // %cast = arith.index_cast %extract : i32 to index
883 //
884 // to just %element.
885 //
886 // Consider expanding this to a template and handle all tensor cast
887 // operations.
888 struct ExtractElementFromIndexCast
889  : public OpRewritePattern<tensor::ExtractOp> {
891 
892  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
893  PatternRewriter &rewriter) const final {
894  Location loc = extract.getLoc();
895  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
896  if (!indexCast)
897  return failure();
898 
899  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
900 
901  auto newExtract = rewriter.create<tensor::ExtractOp>(
902  loc, elementTy, indexCast.getIn(), extract.getIndices());
903 
904  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
905  newExtract);
906 
907  return success();
908  }
909 };
910 
911 } // namespace
912 
913 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
914  MLIRContext *context) {
915  results.add<ExtractElementFromIndexCast>(context);
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // GatherOp
920 //===----------------------------------------------------------------------===//
921 
922 void GatherOp::getAsmResultNames(
923  function_ref<void(Value, StringRef)> setNameFn) {
924  setNameFn(getResult(), "gather");
925 }
926 
927 /// Return the inferred result type for a gatherOp where:
928 /// - sourceType is the type of the source tensor gathered from
929 /// - indicesType is the type of the indices used to gather
930 /// - gatherDims are the dims along which the gather occurs.
931 /// Return a full rank or ranked-reduced variant of the type depending on
932 /// the value of rankReduced.
933 ///
934 /// The leading dimensions of the index tensor give the result tensor its
935 /// leading dimensions.
936 /// The trailing dimensions of the result tensor are obtained from the source
937 /// tensor by setting the dimensions specified in gather_dims to `1` (if
938 /// rankedReduced is false), or skipping them (otherwise).
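/// Illustrative example (hypothetical shapes): gathering along dim 1 of a
/// tensor<4x5x6xf32> source with indices of type tensor<2x3x1xindex> infers
/// tensor<2x3x4x1x6xf32> (full rank) or tensor<2x3x4x6xf32> (rank-reduced).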
939 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
940  RankedTensorType indicesType,
941  ArrayRef<int64_t> gatherDims,
942  bool rankReduced) {
943  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
944  resultShape.reserve(resultShape.size() + sourceType.getRank());
945  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
946  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
947  if (!rankReduced)
948  resultShape.push_back(1);
949  continue;
950  }
951  resultShape.push_back(sourceType.getDimSize(idx));
952  }
953  return RankedTensorType::Builder(sourceType).setShape(resultShape);
954 }
955 
956 static LogicalResult
957 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
958  StringRef gatherOrScatter, StringRef sourceOrDest) {
959  if (dims.empty())
960  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
961 
962  int64_t numGatherDims = dims.size();
963  if (numGatherDims > rank)
964  return op->emitOpError(gatherOrScatter)
965  << "_dims overflow " << sourceOrDest << " rank";
966  for (int64_t val : dims) {
967  if (val < 0)
968  return op->emitOpError(gatherOrScatter)
969  << "_dims value must be non-negative";
970  if (val >= rank)
971  return op->emitOpError(gatherOrScatter)
972  << "_dims value must be smaller than " << sourceOrDest << " rank";
973  }
974  for (int64_t i = 1; i < numGatherDims; ++i) {
975  if (dims[i - 1] >= dims[i])
976  return op->emitOpError(gatherOrScatter)
977  << "_dims values must be strictly increasing";
978  }
979  return success();
980 }
981 
982 LogicalResult GatherOp::verify() {
983  int64_t sourceRank = getSourceType().getRank();
984  ArrayRef<int64_t> gatherDims = getGatherDims();
985  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
986  "gather", "source")))
987  return failure();
988 
989  RankedTensorType expectedResultType = GatherOp::inferResultType(
990  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
991  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
992  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
993  if (getResultType() != expectedResultType &&
994  getResultType() != expectedRankReducedResultType) {
995  return emitOpError("result type "
996  "mismatch: "
997  "expected ")
998  << expectedResultType << " or its rank-reduced variant "
999  << expectedRankReducedResultType << " (got: " << getResultType()
1000  << ")";
1001  }
1002 
1003  return success();
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // InsertOp
1008 //===----------------------------------------------------------------------===//
1009 
1010 void InsertOp::getAsmResultNames(
1011  function_ref<void(Value, StringRef)> setNameFn) {
1012  setNameFn(getResult(), "inserted");
1013 }
1014 
1015 LogicalResult InsertOp::verify() {
1016  // Verify the # indices match if we have a ranked type.
1017  auto destType = getDest().getType().cast<RankedTensorType>();
1018  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1019  return emitOpError("incorrect number of indices");
1020  return success();
1021 }
1022 
1023 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1024  Attribute scalar = adaptor.getScalar();
1025  Attribute dest = adaptor.getDest();
1026  if (scalar && dest)
1027  if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
1028  if (scalar == splatDest.getSplatValue<Attribute>())
1029  return dest;
1030  return {};
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // GenerateOp
1035 //===----------------------------------------------------------------------===//
1036 
1037 void GenerateOp::getAsmResultNames(
1038  function_ref<void(Value, StringRef)> setNameFn) {
1039  setNameFn(getResult(), "generated");
1040 }
1041 
1042 LogicalResult GenerateOp::reifyResultShapes(
1043  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1044  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1045  int idx = 0;
1046  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1047  if (getType().isDynamicDim(dim)) {
1048  reifiedReturnShapes[0][dim] = getOperand(idx++);
1049  } else {
1050  reifiedReturnShapes[0][dim] =
1051  builder.getIndexAttr(getType().getDimSize(dim));
1052  }
1053  }
1054  return success();
1055 }
1056 
1057 LogicalResult GenerateOp::verify() {
1058  // Ensure that the tensor type has as many dynamic dimensions as are
1059  // specified by the operands.
1060  RankedTensorType resultTy = getType().cast<RankedTensorType>();
1061  if (getNumOperands() != resultTy.getNumDynamicDims())
1062  return emitError("must have as many index operands as dynamic extents "
1063  "in the result type");
1064 
1065  return success();
1066 }
1067 
1068 LogicalResult GenerateOp::verifyRegions() {
1069  RankedTensorType resultTy = getType().cast<RankedTensorType>();
1070  // Ensure that region arguments span the index space.
1071  if (!llvm::all_of(getBody().getArgumentTypes(),
1072  [](Type ty) { return ty.isIndex(); }))
1073  return emitError("all body arguments must be index");
1074  if (getBody().getNumArguments() != resultTy.getRank())
1075  return emitError("must have one body argument per input dimension");
1076 
1077  // Ensure that the region yields an element of the right type.
1078  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1079 
1080  if (yieldOp.getValue().getType() != resultTy.getElementType())
1081  return emitOpError(
1082  "body must be terminated with a `yield` operation of the tensor "
1083  "element type");
1084 
1085  return success();
1086 }
1087 
1088 void GenerateOp::build(
1089  OpBuilder &b, OperationState &result, Type resultTy,
1090  ValueRange dynamicExtents,
1091  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1092  build(b, result, resultTy, dynamicExtents);
1093 
1094  // Build and populate body.
1095  OpBuilder::InsertionGuard guard(b);
1096  Region *bodyRegion = result.regions.front().get();
1097  auto rank = resultTy.cast<RankedTensorType>().getRank();
1098  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1099  SmallVector<Location, 2> argumentLocs(rank, result.location);
1100  Block *bodyBlock =
1101  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1102  bodyBuilder(b, result.location, bodyBlock->getArguments());
1103 }
1104 
1105 namespace {
1106 
1107 /// Canonicalizes tensor.generate operations with a constant
1108 /// operand into the equivalent operation with the operand expressed in the
1109 /// result type, instead. We also insert a type cast to make sure that the
1110 /// resulting IR is still well-typed.
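/// Illustrative example:
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %t = tensor.generate %c8 { ... } : tensor<?xindex>
/// ```
///
/// becomes a `tensor.generate` producing tensor<8xindex>, followed by a
/// `tensor.cast` back to tensor<?xindex> for the existing uses.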
1111 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1113 
1114  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
1115  PatternRewriter &rewriter) const final {
1116  auto resultType =
1117  tensorFromElements.getResult().getType().cast<RankedTensorType>();
1118 
1119  if (resultType.hasStaticShape())
1120  return failure();
1121 
1122  SmallVector<Value, 4> newOperands;
1123  SmallVector<int64_t, 4> newShape;
1124  auto operandsIt = tensorFromElements.getDynamicExtents().begin();
1125 
1126  for (int64_t dim : resultType.getShape()) {
1127  if (!ShapedType::isDynamic(dim)) {
1128  newShape.push_back(dim);
1129  continue;
1130  }
1131  APInt index;
1132  if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
1133  newShape.push_back(ShapedType::kDynamic);
1134  newOperands.push_back(*operandsIt++);
1135  continue;
1136  }
1137  newShape.push_back(index.getSExtValue());
1138  operandsIt++;
1139  }
1140 
1141  if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
1142  return failure();
1143 
1144  auto loc = tensorFromElements.getLoc();
1145  auto newOp = rewriter.create<GenerateOp>(
1146  loc, RankedTensorType::get(newShape, resultType.getElementType()),
1147  newOperands);
1148  rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
1149  newOp.getBody().begin());
1150  rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
1151  newOp);
1152  return success();
1153  }
1154 };
1155 
1156 /// Canonicalizes the pattern of the form
1157 ///
1158 /// %tensor = tensor.generate %x {
1159 /// ^bb0(%arg0: index):
1160 /// <computation>
1161 /// yield %1 : index
1162 /// } : tensor<?xindex>
1163 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1164 ///
1165 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1166 /// tensor.generate operation has no side-effects.
1167 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1169 
1170  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1171  PatternRewriter &rewriter) const final {
1172  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1173  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1174  return failure();
1175 
1176  IRMapping mapping;
1177  Block *body = &tensorFromElements.getBody().front();
1178  mapping.map(body->getArguments(), extract.getIndices());
1179  for (auto &op : body->without_terminator())
1180  rewriter.clone(op, mapping);
1181 
1182  auto yield = cast<YieldOp>(body->getTerminator());
1183 
1184  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1185  return success();
1186  }
1187 };
1188 
1189 } // namespace
1190 
1191 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1192  MLIRContext *context) {
1193  // TODO: Move extract pattern to tensor::ExtractOp.
1194  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // RankOp
1199 //===----------------------------------------------------------------------===//
1200 
1201 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1202  setNameFn(getResult(), "rank");
1203 }
1204 
1205 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1206  // Constant fold rank when the rank of the operand is known.
1207  auto type = getOperand().getType();
1208  auto shapedType = type.dyn_cast<ShapedType>();
1209  if (shapedType && shapedType.hasRank())
1210  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1211  return IntegerAttr();
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // ReshapeOp
1216 //===----------------------------------------------------------------------===//
1217 
1218 void ReshapeOp::getAsmResultNames(
1219  function_ref<void(Value, StringRef)> setNameFn) {
1220  setNameFn(getResult(), "reshape");
1221 }
1222 
1223 static int64_t getNumElements(ShapedType type) {
1224  int64_t numElements = 1;
1225  for (auto dim : type.getShape())
1226  numElements *= dim;
1227  return numElements;
1228 }
1229 
1230 LogicalResult ReshapeOp::verify() {
1231  TensorType operandType = getSource().getType().cast<TensorType>();
1232  TensorType resultType = getResult().getType().cast<TensorType>();
1233 
1234  if (operandType.getElementType() != resultType.getElementType())
1235  return emitOpError("element types of source and destination tensor "
1236  "types should be the same");
1237 
1238  int64_t shapeSize =
1239  getShape().getType().cast<RankedTensorType>().getDimSize(0);
1240  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
1241  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
1242 
1243  if (resultRankedType) {
1244  if (operandRankedType && resultRankedType.hasStaticShape() &&
1245  operandRankedType.hasStaticShape()) {
1246  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1247  return emitOpError("source and destination tensor should have the "
1248  "same number of elements");
1249  }
1250  if (ShapedType::isDynamic(shapeSize))
1251  return emitOpError("cannot use shape operand with dynamic length to "
1252  "reshape to statically-ranked tensor type");
1253  if (shapeSize != resultRankedType.getRank())
1254  return emitOpError(
1255  "length of shape operand differs from the result's tensor rank");
1256  }
1257  return success();
1258 }
1259 
1260 //===----------------------------------------------------------------------===//
1261 // Reassociative reshape ops
1262 //===----------------------------------------------------------------------===//
1263 
1264 void CollapseShapeOp::getAsmResultNames(
1265  function_ref<void(Value, StringRef)> setNameFn) {
1266  setNameFn(getResult(), "collapsed");
1267 }
1268 
1269 void ExpandShapeOp::getAsmResultNames(
1270  function_ref<void(Value, StringRef)> setNameFn) {
1271  setNameFn(getResult(), "expanded");
1272 }
1273 
1274 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1275  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1276  "invalid resultDim");
1277  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1278  if (llvm::is_contained(it.value(), resultDim))
1279  return it.index();
1280  llvm_unreachable("could not find reassociation group");
1281 }
1282 
1283 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1284  return getSymbolLessAffineMaps(getReassociationExprs());
1285 }
1286 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1287  return convertReassociationIndicesToExprs(getContext(),
1288  getReassociationIndices());
1289 }
1290 
1291 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1292  return getSymbolLessAffineMaps(getReassociationExprs());
1293 }
1294 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1295  return convertReassociationIndicesToExprs(getContext(),
1296  getReassociationIndices());
1297 }
1298 
1299 RankedTensorType CollapseShapeOp::inferCollapsedType(
1300  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1301  return inferCollapsedType(
1302  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1303  type.getContext(), reassociation)));
1304 }
1305 
1306 /// Compute the RankedTensorType obtained by applying `reassociation` to
1307 /// `type`.
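/// For instance (illustrative): collapsing tensor<?x4x5xf32> with reassociation
/// [[0, 1], [2]] yields tensor<?x5xf32>; any group that contains a dynamic
/// dimension collapses to a dynamic dimension.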
1308 RankedTensorType
1309 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1310  ArrayRef<AffineMap> reassociation) {
1311  auto shape = type.getShape();
1312  SmallVector<int64_t, 4> newShape;
1313  newShape.reserve(reassociation.size());
1314 
1315  // Use the fact that reassociation is valid to simplify the logic: only use
1316  // each map's rank.
1317  assert(isReassociationValid(reassociation) && "invalid reassociation");
1318  unsigned currentDim = 0;
1319  for (AffineMap m : reassociation) {
1320  unsigned dim = m.getNumResults();
1321  auto band = shape.slice(currentDim, dim);
1322  int64_t size = 1;
1323  if (llvm::is_contained(band, ShapedType::kDynamic))
1324  size = ShapedType::kDynamic;
1325  else
1326  for (unsigned d = 0; d < dim; ++d)
1327  size *= shape[currentDim + d];
1328  newShape.push_back(size);
1329  currentDim += dim;
1330  }
1331 
1332  return RankedTensorType::get(newShape, type.getElementType());
1333 }
1334 
1335 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1336  ArrayRef<ReassociationIndices> reassociation,
1337  ArrayRef<NamedAttribute> attrs) {
1338  auto resultType = inferCollapsedType(
1339  src.getType().cast<RankedTensorType>(),
1340  getSymbolLessAffineMaps(
1341  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1342  build(b, result, resultType, src, attrs);
1343  result.addAttribute(getReassociationAttrStrName(),
1344  getReassociationIndicesAttribute(b, reassociation));
1345 }
1346 
1347 // Checks if types are the same, but ignoring encoding on ranked tensors.
1348 static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
1349  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
1350  if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
1351  return rtp1.getShape() == rtp2.getShape() &&
1352  rtp1.getElementType() == rtp2.getElementType();
1353  return false;
1354  }
1355  // Default implementation.
1356  return tp1 == tp2;
1357 }
1358 
1359 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1360  TensorReshapeOp, ExpandShapeOp>::value>
1361 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1362  RankedTensorType expandedType,
1363  RankedTensorType collapsedType) {
1364  if (failed(
1365  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1366  return failure();
1367 
1368  auto maps = op.getReassociationMaps();
1369  RankedTensorType expectedType =
1370  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1371  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
1372  return op.emitOpError("expected collapsed type to be ")
1373  << expectedType << ", but got " << collapsedType;
1374  return success();
1375 }
1376 
1377 LogicalResult ExpandShapeOp::verify() {
1378  auto srcType = getSrcType();
1379  auto resultType = getResultType();
1380  if (srcType.getRank() >= resultType.getRank())
1381  return emitOpError("expected rank expansion, but found source rank ")
1382  << srcType.getRank() << " >= result rank " << resultType.getRank();
1383 
1384  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
1385 }
1386 
1387 LogicalResult CollapseShapeOp::verify() {
1388  auto srcType = getSrcType();
1389  auto resultType = getResultType();
1390  if (srcType.getRank() <= resultType.getRank())
1391  return emitOpError("expected rank reduction, but found source rank ")
1392  << srcType.getRank() << " <= result rank " << resultType.getRank();
1393 
1394  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1395 }
1396 
1397 namespace {
1398 /// Reshape of a splat constant can be replaced with a constant of the result
1399 /// type.
1400 template <typename TensorReshapeOp>
1401 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1403  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1404  PatternRewriter &rewriter) const override {
1405  DenseElementsAttr attr;
1406  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1407  return failure();
1408  if (!attr || !attr.isSplat())
1409  return failure();
1410  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1411  reshapeOp.getResultType(), attr.getRawData());
1412  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1413  return success();
1414  }
1415 };
1416 
1417 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1418 template <typename TensorReshapeOp>
1419 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1420 public:
1422 
1423  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1424  PatternRewriter &rewriter) const override {
1425  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1426  if (!splatOp)
1427  return failure();
1428 
1429  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1430  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1431  return success();
1432  }
1433 };
1434 
1435 /// Reshape of a FromElements can be replaced with a FromElements of the
1436 /// result type
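/// (illustrative: expanding a tensor.from_elements of type tensor<2xf32> to a
/// static tensor<1x2xf32> is rewritten as a tensor.from_elements of the result
/// type with the same scalar operands).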
1437 template <typename TensorReshapeOp>
1438 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1440  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1441  PatternRewriter &rewriter) const override {
1442  auto fromElements =
1443  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1444  if (!fromElements)
1445  return failure();
1446 
1447  auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
1448 
1449  if (!shapedTy.hasStaticShape())
1450  return failure();
1451 
1452  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1453  fromElements.getElements());
1454  return success();
1455  }
1456 };
1457 
1458 // Fold CastOp into CollapseShapeOp when adding static information.
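// Illustrative example: if %c = tensor.cast %s : tensor<4x4x?xf32> to
// tensor<?x?x?xf32> feeds a collapse_shape, the collapse can consume %s
// directly; a cast of the collapsed result is emitted when the inferred
// (more static) result type differs from the original one.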
1459 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1461 
1462  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1463  PatternRewriter &rewriter) const override {
1464  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1465  if (!tensor::canFoldIntoConsumerOp(castOp))
1466  return failure();
1467 
1468  RankedTensorType srcType =
1469  castOp.getSource().getType().cast<RankedTensorType>();
1470  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1471  srcType, collapseShapeOp.getReassociationMaps());
1472 
1473  if (newResultType == collapseShapeOp.getResultType()) {
1474  rewriter.updateRootInPlace(collapseShapeOp, [&]() {
1475  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1476  });
1477  } else {
1478  auto newOp = rewriter.create<CollapseShapeOp>(
1479  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1480  collapseShapeOp.getReassociation());
1481  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1482  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1483  }
1484  return success();
1485  }
1486 };
1487 
1488 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1490 
1491  LogicalResult matchAndRewrite(DimOp dimOp,
1492  PatternRewriter &rewriter) const override {
1493  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1494  if (!expandShapeOp)
1495  return failure();
1496 
1497  // Only constant dimension values are supported.
1498  std::optional<int64_t> dim = dimOp.getConstantIndex();
1499  if (!dim.has_value())
1500  return failure();
1501 
1502  // Skip static dims. These are folded to constant ops.
1503  TensorType resultType = expandShapeOp.getResultType();
1504  if (!resultType.isDynamicDim(*dim))
1505  return failure();
1506 
1507  // Find reassociation group that contains this result dimension.
1508  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1509 
1510  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1511  // ExpandShapeOp would be ambiguous.)
1512  int64_t product = 1;
1513  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1514  for (int64_t d : grp) {
1515  if (d != dim) {
1516  assert(!resultType.isDynamicDim(d) && "expected static dim");
1517  product *= resultType.getDimSize(d);
1518  }
1519  }
1520 
1521  // result dim size = src dim size / (product(other dims in reassoc group))
1522  Value srcDimSz =
1523  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1524  AffineExpr expr;
1525  bindSymbols(dimOp.getContext(), expr);
1526  rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, expr.floorDiv(product),
1527  srcDimSz);
1528  return success();
1529  }
1530 };
1531 
1532 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1534 
1535  LogicalResult matchAndRewrite(DimOp dimOp,
1536  PatternRewriter &rewriter) const override {
1537  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1538  if (!collapseShapeOp)
1539  return failure();
1540 
1541  // Only constant dimension values are supported.
1542  std::optional<int64_t> dim = dimOp.getConstantIndex();
1543  if (!dim.has_value())
1544  return failure();
1545 
1546  // Skip static dims. These are folded to constant ops.
1547  TensorType resultType = collapseShapeOp.getResultType();
1548  if (!resultType.isDynamicDim(*dim))
1549  return failure();
1550 
1551  // Get reassociation group of the result dimension.
1552  ReassociationIndices group =
1553  collapseShapeOp.getReassociationIndices()[*dim];
1554 
1555  // result dim size = product(dims in reassoc group)
1556  SmallVector<Value> srcDimSizes;
1557  SmallVector<AffineExpr, 4> syms;
1558  AffineExpr product;
1559  for (const auto &it : llvm::enumerate(group)) {
1560  srcDimSizes.push_back(rewriter.create<DimOp>(
1561  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1562  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1563  product = product ? product * syms.back() : syms.back();
1564  }
1565  rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, product, srcDimSizes);
1566  return success();
1567  }
1568 };
1569 } // namespace
1570 
1571 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1572  MLIRContext *context) {
1573  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
1574  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
1575  FoldReshapeWithConstant<ExpandShapeOp>,
1576  FoldReshapeWithSplat<ExpandShapeOp>,
1577  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1578  FoldDimOfCollapseShape>(context);
1579 }
1580 
1581 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1582  MLIRContext *context) {
1583  results
1584  .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
1585  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
1586  FoldReshapeWithConstant<CollapseShapeOp>,
1587  FoldReshapeWithSplat<CollapseShapeOp>,
1588  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
1589  context);
1590 }
1591 
1592 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
1593  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
1594  adaptor.getOperands());
1595 }
1596 
1597 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
1598  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
1599  adaptor.getOperands());
1600 }
1601 
1602 //===----------------------------------------------------------------------===//
1603 // ExtractSliceOp
1604 //===----------------------------------------------------------------------===//
1605 
1606 void ExtractSliceOp::getAsmResultNames(
1607  function_ref<void(Value, StringRef)> setNameFn) {
1608  setNameFn(getResult(), "extracted_slice");
1609 }
1610 
1611 /// An extract_slice result type can be inferred, when it is not
1612 /// rank-reduced, from the source type and the static representation of
1613 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
1614 RankedTensorType ExtractSliceOp::inferResultType(
1615  ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
1616  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
1617  // An extract_slice op may specify only a leading subset of offsets/sizes/
1618  // strides, in which case we complete with offset=0, sizes from the source
1619  // tensor type, and strides=1.
1620  assert(static_cast<int64_t>(staticSizes.size()) ==
1621  sourceShapedTensorType.getRank() &&
1622  "unexpected staticSizes not equal to rank of source");
1623  return RankedTensorType::get(staticSizes,
1624  sourceShapedTensorType.getElementType());
1625 }
1626 
1627 RankedTensorType ExtractSliceOp::inferResultType(
1628  ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
1629  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
1630  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1631  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1632  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1633  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1634  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1635  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
1636  staticSizes, staticStrides);
1637 }
1638 
1639 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
1640 /// number of sizes), drop as many size 1 as needed to produce an inferred
1641 /// type with the desired rank.
1642 ///
1643 /// Note that there may be multiple ways to compute this rank-reduced type:
1644 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
1645 ///
1646 /// To disambiguate, this function always drops the leading size-1 dimensions.
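///
/// For example (illustrative, not from this file): with an inferred shape of
/// 1x6x1 and a desired result rank of 2, the leading size-1 dimension is
/// dropped, yielding a 6x1 result rather than a 1x6 one.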
1647 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1648  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1649  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1650  ArrayRef<int64_t> strides) {
1651  // Type inferred in the absence of rank-reducing behavior.
1652  auto inferredType =
1653  inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1654  .cast<RankedTensorType>();
1655  int rankDiff = inferredType.getRank() - desiredResultRank;
1656  if (rankDiff > 0) {
1657  auto shape = inferredType.getShape();
1658  llvm::SmallBitVector dimsToProject =
1659  getPositionsOfShapeOne(rankDiff, shape);
1660  SmallVector<int64_t> projectedShape;
1661  // Best effort rank-reducing: drop 1s in order.
1662  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1663  if (!dimsToProject.test(pos))
1664  projectedShape.push_back(shape[pos]);
1665  inferredType =
1666  RankedTensorType::get(projectedShape, inferredType.getElementType());
1667  }
1668  return inferredType;
1669 }
1670 
1671 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1672  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1673  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1674  ArrayRef<OpFoldResult> strides) {
1675  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1676  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1677  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1678  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1679  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1680  return ExtractSliceOp::inferCanonicalRankReducedResultType(
1681  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1682  staticStrides);
1683 }
1684 
1685 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1686 /// result type. If the type passed is nullptr, it is inferred.
1687 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1688  RankedTensorType resultType, Value source,
1689  ArrayRef<OpFoldResult> offsets,
1690  ArrayRef<OpFoldResult> sizes,
1691  ArrayRef<OpFoldResult> strides,
1692  ArrayRef<NamedAttribute> attrs) {
1693  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1694  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1695  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1696  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1697  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1698  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
1699  // Structuring implementation this way avoids duplication between builders.
1700  if (!resultType) {
1701  resultType =
1702  ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
1703  staticSizes, staticStrides)
1704  .cast<RankedTensorType>();
1705  }
1706  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1707  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1708  b.getDenseI64ArrayAttr(staticSizes),
1709  b.getDenseI64ArrayAttr(staticStrides));
1710  result.addAttributes(attrs);
1711 }
1712 
1713 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1714 /// result type.
1715 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1716  ArrayRef<OpFoldResult> offsets,
1717  ArrayRef<OpFoldResult> sizes,
1718  ArrayRef<OpFoldResult> strides,
1719  ArrayRef<NamedAttribute> attrs) {
1720  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1721 }
1722 
1723 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
1724 /// a Range vector.
1725 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1726  ArrayRef<Range> ranges,
1727  ArrayRef<NamedAttribute> attrs) {
1728  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
1729  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1730 }
1731 
1732 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
1733 /// the type passed is nullptr, it is inferred.
1734 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1735  RankedTensorType resultType, Value source,
1736  ValueRange offsets, ValueRange sizes,
1737  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1738  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1739  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1740  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1741  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1742  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1743  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1744  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1745 }
1746 
1747 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
1748 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1749  ValueRange offsets, ValueRange sizes,
1750  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1751  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1752 }
1753 
1754 template <typename OpTy>
1755 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
1756  OpTy op, Type expectedType) {
1757  auto shapedType = expectedType.cast<ShapedType>();
1758  switch (result) {
1759  case SliceVerificationResult::Success:
1760  return success();
1761  case SliceVerificationResult::RankTooLarge:
1762  return op.emitError("expected rank to be smaller or equal to ")
1763  << "the other rank. ";
1764  case SliceVerificationResult::SizeMismatch:
1765  return op.emitError("expected type to be ")
1766  << expectedType << " or a rank-reduced version. (size mismatch) ";
1767  case SliceVerificationResult::ElemTypeMismatch:
1768  return op.emitError("expected element type to be ")
1769  << shapedType.getElementType();
1770  default:
1771  llvm_unreachable("unexpected extract_slice op verification result");
1772  }
1773 }
1774 
1775 /// Verifier for ExtractSliceOp.
1776 LogicalResult ExtractSliceOp::verify() {
1777  // Verify result type against inferred type.
1778  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
1779  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1780  SliceVerificationResult result = isRankReducedType(expectedType, getType());
1781  return produceSliceErrorMsg(result, *this, expectedType);
1782 }
1783 
1784 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
1785  return ::getDroppedDims(getType().getShape(), getMixedSizes());
1786 }
1787 
1788 FailureOr<Value>
1789 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
1790  ArrayRef<int64_t> desiredShape) {
1791  auto sourceTensorType = value.getType().dyn_cast<RankedTensorType>();
1792  assert(sourceTensorType && "not a ranked tensor type");
1793  auto sourceShape = sourceTensorType.getShape();
1794  if (sourceShape.equals(desiredShape))
1795  return value;
1796  auto maybeRankReductionMask =
1797  mlir::computeRankReductionMask(sourceShape, desiredShape);
1798  if (!maybeRankReductionMask)
1799  return failure();
1800  return createCanonicalRankReducingExtractSliceOp(
1801  b, loc, value,
1802  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
1803 }
1804 
1805 LogicalResult ExtractSliceOp::reifyResultShapes(
1806  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1807  reifiedReturnShapes.resize(1);
1808  reifiedReturnShapes[0].reserve(getType().getRank());
1809  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1810  llvm::SmallBitVector droppedDims = getDroppedDims();
1811  for (const auto &size : enumerate(mixedSizes)) {
1812  if (droppedDims.test(size.index()))
1813  continue;
1814  reifiedReturnShapes[0].push_back(size.value());
1815  }
1816  return success();
1817 }
1818 
1819 namespace {
1820 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
1821 /// This essentially pushes the tensor.cast past its consuming slice when
1822 /// `canFoldIntoConsumerOp` is true.
1823 ///
1824 /// Example:
1825 /// ```
1826 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1827 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1828 /// tensor<3x4xf32>
1829 /// ```
1830 /// is rewritten into:
1831 /// ```
1832 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
1833 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
1834 /// ```
1835 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1836 public:
1837  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1838 
1839  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
1840  PatternRewriter &rewriter) const override {
1841  // If any operand is a constant, bail out and let the constant folder run first.
1842  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1843  return matchPattern(operand, matchConstantIndex());
1844  }))
1845  return failure();
1846 
1847  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
1848  if (!castOp)
1849  return failure();
1850 
1851  if (!canFoldIntoConsumerOp(castOp))
1852  return failure();
1853 
1854  /// Deduce the type of the result to use for the canonicalized operation.
1855  Location loc = sliceOp.getLoc();
1856  auto sliceOpType = sliceOp.getType();
1857  RankedTensorType resultType =
1858  ExtractSliceOp::inferCanonicalRankReducedResultType(
1859  sliceOpType.getRank(), sliceOp.getSourceType(),
1860  sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
1861  sliceOp.getMixedStrides());
1862  Value newResult = rewriter.create<ExtractSliceOp>(
1863  loc, resultType, castOp.getSource(), sliceOp.getOffsets(),
1864  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
1865  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1866  if (newResult.getType() != sliceOpType)
1867  newResult = rewriter.create<CastOp>(loc, sliceOpType, newResult);
1868  rewriter.replaceOp(sliceOp, newResult);
1869  return success();
1870  }
1871 };
1872 
1873 /// Slice elements from `values` into `outValues`. For each dimension,
1874 /// `counts` holds the number of `values` elements that one step along that
1875 /// dimension strides over. The output can be used to build a DenseElementsAttr.
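///
/// Illustrative example (hypothetical shape, not from this file): for a source
/// of shape 2x3x4, `counts` is [12, 4, 1], i.e. the row-major strides of the
/// source; the recursion advances `values` by `offset * counts.front()` at
/// each outer dimension.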
1876 template <typename IterTy, typename ElemTy>
1877 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
1878  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1879  ArrayRef<int64_t> strides,
1880  llvm::SmallVectorImpl<ElemTy> *outValues) {
1881  assert(offsets.size() == sizes.size());
1882  assert(offsets.size() == strides.size());
1883  if (offsets.empty())
1884  return;
1885 
1886  int64_t offset = offsets.front();
1887  int64_t size = sizes.front();
1888  int64_t stride = strides.front();
1889  if (offsets.size() == 1) {
1890  for (int64_t i = 0; i < size; ++i, offset += stride)
1891  outValues->push_back(*(values + offset));
1892 
1893  return;
1894  }
1895 
1896  for (int64_t i = 0; i < size; ++i, offset += stride) {
1897  auto begin = values + offset * counts.front();
1898  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
1899  offsets.drop_front(), sizes.drop_front(),
1900  strides.drop_front(), outValues);
1901  }
1902 }
1903 
1904 /// Fold a tensor.extract_slice of an arith.constant into a new, smaller
1905 /// arith.constant. Folding may duplicate constant data; users can restrict
1906 /// when it applies via the control function.
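///
/// Illustrative sketch (hypothetical values, not from this file):
///
/// ```mlir
/// %cst = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
/// %s = tensor.extract_slice %cst[0, 1] [2, 2] [1, 1]
///     : tensor<2x3xi32> to tensor<2x2xi32>
/// ```
///
/// may fold to:
///
/// ```mlir
/// %s = arith.constant dense<[[2, 3], [5, 6]]> : tensor<2x2xi32>
/// ```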
1907 class ConstantOpExtractSliceFolder final
1908  : public OpRewritePattern<ExtractSliceOp> {
1909 public:
1910  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1911 
1912  ConstantOpExtractSliceFolder(MLIRContext *context,
1913  ControlConstantExtractSliceFusionFn controlFn)
1914  : OpRewritePattern<ExtractSliceOp>(context),
1915  controlFn(std::move(controlFn)) {}
1916 
1917  LogicalResult matchAndRewrite(ExtractSliceOp op,
1918  PatternRewriter &rewriter) const override {
1919  DenseElementsAttr attr;
1920  if (!matchPattern(op.getSource(), m_Constant(&attr)))
1921  return failure();
1922 
1923  // A constant splat is handled by fold().
1924  if (attr.isSplat())
1925  return failure();
1926 
1927  // Dynamic result shape is not supported.
1928  auto sourceType = op.getSource().getType().cast<ShapedType>();
1929  auto resultType = op.getResult().getType().cast<ShapedType>();
1930  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
1931  return failure();
1932 
1933  // Customized control over the folding.
1934  if (!controlFn(op))
1935  return failure();
1936 
1937  int64_t count = sourceType.getNumElements();
1938  if (count == 0)
1939  return failure();
1940 
1941  // Check if there are any dynamic parts, which are not supported.
1942  auto offsets = op.getStaticOffsets();
1943  if (llvm::is_contained(offsets, ShapedType::kDynamic))
1944  return failure();
1945  auto sizes = op.getStaticSizes();
1946  if (llvm::is_contained(sizes, ShapedType::kDynamic))
1947  return failure();
1948  auto strides = op.getStaticStrides();
1949  if (llvm::is_contained(strides, ShapedType::kDynamic))
1950  return failure();
1951 
1952  // Compute the stride for each dimension.
1953  SmallVector<int64_t> counts;
1954  ArrayRef<int64_t> shape = sourceType.getShape();
1955  counts.reserve(shape.size());
1956  for (int64_t v : shape) {
1957  count = count / v;
1958  counts.push_back(count);
1959  }
1960 
1961  // New attribute constructed by the sliced values.
1962  DenseElementsAttr newAttr;
1963 
1964  if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
1965  SmallVector<APInt> outValues;
1966  outValues.reserve(sourceType.getNumElements());
1967  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
1968  elems.begin(), counts, offsets, sizes, strides, &outValues);
1969  newAttr = DenseElementsAttr::get(resultType, outValues);
1970  } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
1971  SmallVector<APFloat> outValues;
1972  outValues.reserve(sourceType.getNumElements());
1973  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
1974  elems.begin(), counts, offsets, sizes, strides, &outValues);
1975  newAttr = DenseElementsAttr::get(resultType, outValues);
1976  }
1977 
1978  if (newAttr) {
1979  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
1980  return success();
1981  }
1982 
1983  return failure();
1984  }
1985 
1986 private:
1987  /// Controls whether the fold applies; users can encode their heuristics in
1988  /// this function.
1989  ControlConstantExtractSliceFusionFn controlFn;
1990 };
1991 
1992 } // namespace
1993 
1994 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
1995  RewritePatternSet &patterns,
1996  const ControlConstantExtractSliceFusionFn &controlFn) {
1997  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
1998 }
1999 
2000 /// Return the canonical type of the result of an extract_slice op.
2001 struct SliceReturnTypeCanonicalizer {
2002  RankedTensorType operator()(ExtractSliceOp op,
2003  ArrayRef<OpFoldResult> mixedOffsets,
2004  ArrayRef<OpFoldResult> mixedSizes,
2005  ArrayRef<OpFoldResult> mixedStrides) {
2006  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2007  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2008  mixedStrides);
2009  }
2010 };
2011 
2012 /// A canonicalizer wrapper to replace ExtractSliceOps.
2013 struct SliceCanonicalizer {
2014  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2015  ExtractSliceOp newOp) {
2016  Value replacement = newOp.getResult();
2017  if (replacement.getType() != op.getType())
2018  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2019  replacement);
2020  rewriter.replaceOp(op, replacement);
2021  }
2022 };
2023 
2024 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2025  MLIRContext *context) {
2026  results.add<
2027  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2028  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2029  ExtractSliceOpCastFolder>(context);
2030 }
2031 
2032 /// Succeeds if the op is a no-op slice: zero offsets, unit strides, and sizes equal to `shapedType`'s shape.
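/// For example (illustrative, not from this file), an identity slice such as
/// `tensor.extract_slice %t[0, 0] [4, 8] [1, 1] : tensor<4x8xf32> to
/// tensor<4x8xf32>` satisfies this check and can fold to `%t`.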
2033 static LogicalResult
2034 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2035  ShapedType shapedType) {
2036  OpBuilder b(op.getContext());
2037  for (OpFoldResult ofr : op.getMixedOffsets())
2038  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2039  return failure();
2040  // Rank-reducing noops only need to inspect the leading dimensions:
2041  // llvm::zip is appropriate.
2042  auto shape = shapedType.getShape();
2043  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2044  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2045  return failure();
2046  for (OpFoldResult ofr : op.getMixedStrides())
2047  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2048  return failure();
2049  return success();
2050 }
2051 
2052 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2053 /// slice, we can return the InsertSliceOp's source directly.
2054 // TODO: This only checks the immediate producer; extend to go up the
2055 // insert/extract chain if the slices are disjoint.
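//
// Illustrative sketch (hypothetical shapes, not from this file):
//
//   %0 = tensor.insert_slice %src into %dest[0, 1] [2, 3] [1, 1]
//       : tensor<2x3xf32> into tensor<4x8xf32>
//   %1 = tensor.extract_slice %0[0, 1] [2, 3] [1, 1]
//       : tensor<4x8xf32> to tensor<2x3xf32>
//
// Here %1 folds to %src.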
2056 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2057  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2058 
2059  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2060  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2061  insertOp.isSameAs(extractOp, isSame))
2062  return insertOp.getSource();
2063 
2064  return {};
2065 }
2066 
2067 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2068  if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
2069  auto resultType = getResult().getType().cast<ShapedType>();
2070  if (resultType.hasStaticShape())
2071  return splat.resizeSplat(resultType);
2072  }
2073  if (getSourceType() == getType() &&
2074  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2075  return this->getSource();
2076  if (Value slice = foldExtractAfterInsertSlice(*this))
2077  return slice;
2078 
2079  return OpFoldResult();
2080 }
2081 
2082 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2083  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2084  auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
2085  unsigned rank = rankedTensorType.getRank();
2086  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2087  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2088  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2089  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2090  offsets, sizes, strides);
2091 }
2092 
2093 //===----------------------------------------------------------------------===//
2094 // InsertSliceOp
2095 //===----------------------------------------------------------------------===//
2096 
2097 void InsertSliceOp::getAsmResultNames(
2098  function_ref<void(Value, StringRef)> setNameFn) {
2099  setNameFn(getResult(), "inserted_slice");
2100 }
2101 
2102 // Build an InsertSliceOp with mixed static and dynamic entries.
2103 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2104  Value dest, ArrayRef<OpFoldResult> offsets,
2105  ArrayRef<OpFoldResult> sizes,
2106  ArrayRef<OpFoldResult> strides,
2107  ArrayRef<NamedAttribute> attrs) {
2108  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2109  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2110  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2111  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2112  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2113  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2114  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2115  b.getDenseI64ArrayAttr(staticSizes),
2116  b.getDenseI64ArrayAttr(staticStrides));
2117  result.addAttributes(attrs);
2118 }
2119 
2120 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2121 /// Range vector.
2122 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2123  Value dest, ArrayRef<Range> ranges,
2124  ArrayRef<NamedAttribute> attrs) {
2125  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2126  build(b, result, source, dest, offsets, sizes, strides, attrs);
2127 }
2128 
2129 // Build an InsertSliceOp with dynamic entries.
2130 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2131  Value dest, ValueRange offsets, ValueRange sizes,
2132  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2133  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2134  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2135  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2136  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2137  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2138  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2139  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2140 }
2141 
2142 /// Rank-reducing type verification for both InsertSliceOp and
2143 /// ParallelInsertSliceOp.
2144 static SliceVerificationResult verifyInsertSliceOp(
2145  ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
2146  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
2147  ShapedType *expectedType = nullptr) {
2148  // insert_slice is the inverse of extract_slice, use the same type
2149  // inference.
2150  RankedTensorType expected = ExtractSliceOp::inferResultType(
2151  dstType, staticOffsets, staticSizes, staticStrides);
2152  if (expectedType)
2153  *expectedType = expected;
2154  return isRankReducedType(expected, srcType);
2155 }
2156 
2157 /// Verifier for InsertSliceOp.
2158 LogicalResult InsertSliceOp::verify() {
2159  ShapedType expectedType;
2160  SliceVerificationResult result =
2161  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2162  getStaticSizes(), getStaticStrides(), &expectedType);
2163  return produceSliceErrorMsg(result, *this, expectedType);
2164 }
2165 
2166 /// If we have two consecutive InsertSliceOps writing to the same slice, we
2167 /// can mutate the second InsertSliceOp's destination to the first one's.
2168 ///
2169 /// Example:
2170 ///
2171 /// ```mlir
2172 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2173 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2174 /// ```
2175 ///
2176 /// folds into:
2177 ///
2178 /// ```mlir
2179 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2180 /// ```
2181 ///
2182 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2183 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2184  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2185 
2186  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2187  if (!prevInsertOp ||
2188  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2189  !prevInsertOp.isSameAs(insertOp, isSame))
2190  return failure();
2191 
2192  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2193  return success();
2194 }
2195 
2196 /// Folds round-trip extract/insert slice op pairs.
2197 /// Example:
2198 /// ```mlir
2199 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2200 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2201 /// ```
2202 /// can be folded into %val.
2203 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2204  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2205 
2206  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2207  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2208  !extractOp.isSameAs(insertOp, isSame))
2209  return nullptr;
2210 
2211  return extractOp.getSource();
2212 }
2213 
2214 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2215  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2216  getSourceType() == getType() &&
2217  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2218  return this->getSource();
2219  if (succeeded(foldInsertAfterInsertSlice(*this)))
2220  return getResult();
2221  if (auto result = foldInsertAfterExtractSlice(*this))
2222  return result;
2223  return OpFoldResult();
2224 }
2225 
2226 LogicalResult InsertSliceOp::reifyResultShapes(
2227  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2228  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2229  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
2230  if (getType().isDynamicDim(dim)) {
2231  reifiedReturnShapes[0][dim] =
2232  builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
2233  } else {
2234  reifiedReturnShapes[0][dim] =
2235  builder.getIndexAttr(getType().getDimSize(dim));
2236  }
2237  }
2238  return success();
2239 }
2240 
2241 namespace {
2242 /// Pattern to rewrite an insert_slice op with constant arguments.
2243 ///
2244 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
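///
/// Illustrative sketch (hypothetical shapes, not from this file): given
///
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %r = tensor.insert_slice %src into %dest[%c0, 4] [2, 3] [1, 1]
///     : tensor<2x3xf32> into tensor<8x8xf32>
/// ```
///
/// the dynamic offset %c0 may be folded into the static offsets, yielding
/// `%dest[0, 4] [2, 3] [1, 1]`.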
2245 template <typename InsertOpTy>
2246 class InsertSliceOpConstantArgumentFolder final
2247  : public OpRewritePattern<InsertOpTy> {
2248 public:
2249  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2250 
2251  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2252  PatternRewriter &rewriter) const override {
2253  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2254  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2255  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2256 
2257  // If no constant operands were folded, there is nothing to do.
2258  if (failed(foldDynamicIndexList(rewriter, mixedOffsets)) &&
2259  failed(foldDynamicIndexList(rewriter, mixedSizes)) &&
2260  failed(foldDynamicIndexList(rewriter, mixedStrides)))
2261  return failure();
2262 
2263  // Create the new op in canonical form.
2264  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2265  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2266  mixedOffsets, mixedSizes, mixedStrides);
2267  Value toInsert = insertSliceOp.getSource();
2268  if (sourceType != insertSliceOp.getSourceType()) {
2269  OpBuilder::InsertionGuard g(rewriter);
2270  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2271  // is that the insertion point is just before the ParallelCombiningOp in
2272  // the parallel case.
2273  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2274  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2275  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2276  sourceType, toInsert);
2277  }
2278  rewriter.replaceOpWithNewOp<InsertOpTy>(
2279  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2280  mixedSizes, mixedStrides);
2281  return success();
2282  }
2283 };
2284 
2285 /// Fold tensor_casts with insert_slice operations. If the source or
2286 /// destination tensor is a tensor_cast that removes static type information,
2287 /// the cast is folded into the insert_slice operation. E.g.:
2288 ///
2289 /// ```mlir
2290 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2291 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2292 /// ```
2293 ///
2294 /// folds into:
2295 ///
2296 /// ```mlir
2297 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2298 /// ```
2299 ///
2300 /// Note: When folding a cast on the destination tensor, the result of the
2301 /// insert_slice operation is cast so that the type of the result does not
2302 /// change.
2303 ///
2304 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2305 template <typename InsertOpTy>
2306 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2307  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2308 
2309  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2310  PatternRewriter &rewriter) const override {
2311  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2312  return matchPattern(operand, matchConstantIndex());
2313  }))
2314  return failure();
2315 
2316  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2317  auto castOp = v.getDefiningOp<tensor::CastOp>();
2318  if (!castOp || !canFoldIntoConsumerOp(castOp))
2319  return std::nullopt;
2320  return castOp.getSource();
2321  };
2322  std::optional<Value> sourceCastSource =
2323  getSourceOfCastOp(insertSliceOp.getSource());
2324  std::optional<Value> destCastSource =
2325  getSourceOfCastOp(insertSliceOp.getDest());
2326  if (!sourceCastSource && !destCastSource)
2327  return failure();
2328 
2329  auto src =
2330  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2331  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2332  auto srcType = src.getType().template cast<ShapedType>();
2333  auto dstType = dst.getType().template cast<ShapedType>();
2334  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2335  insertSliceOp.getStaticSizes(),
2336  insertSliceOp.getStaticStrides()) !=
2337  SliceVerificationResult::Success)
2338  return failure();
2339 
2340  Operation *replacement = rewriter.create<InsertOpTy>(
2341  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2342  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2343 
2344  // In the parallel case there is no result and so nothing to cast.
2345  bool isParallelInsert =
2346  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2347  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2348  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2349  insertSliceOp.getDestType(),
2350  replacement->getResult(0));
2351  }
2352  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2353  return success();
2354  }
2355 };
2356 
2357 /// If additional static type information can be deduced from an insert_slice's
2358 /// size operands, insert an explicit cast of the op's source operand. This
2359 /// enables other canonicalization patterns that are matching for tensor_cast
2360 /// ops such as `ForOpTensorCastFolder` in SCF.
2361 ///
2362 /// Example:
2363 ///
2364 /// ```mlir
2365 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2366 /// : tensor<?x?xf32> into ...
2367 /// ```
2368 ///
2369 /// folds into:
2370 ///
2371 /// ```mlir
2372 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2373 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2374 /// : tensor<64x64xf32> into ...
2375 /// ```
2376 ///
2377 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2378 template <typename InsertOpTy>
2379 struct InsertSliceOpSourceCastInserter final
2380  : public OpRewritePattern<InsertOpTy> {
2381  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2382 
2383  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2384  PatternRewriter &rewriter) const override {
2385  RankedTensorType srcType = insertSliceOp.getSourceType();
2386  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2387  return failure();
2388  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
2389  srcType.getShape().end());
2390  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2391  if (std::optional<int64_t> constInt =
2392  getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
2393  newSrcShape[i] = *constInt;
2394  }
2395 
2396  RankedTensorType newSrcType =
2397  RankedTensorType::get(newSrcShape, srcType.getElementType());
2398  if (srcType == newSrcType ||
2399  !preservesStaticInformation(srcType, newSrcType) ||
2400  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2401  return failure();
2402 
2403  // newSrcType is:
2404  // 1) Different from srcType.
2405  // 2) "More static" than srcType.
2406  // 3) Cast-compatible with srcType.
2407  // Insert the cast.
2408  OpBuilder::InsertionGuard g(rewriter);
2409  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2410  // that the insertion point is just before the ParallelCombiningOp in the
2411  // parallel case.
2412  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2413  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2414  Value cast = rewriter.create<tensor::CastOp>(
2415  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2416  rewriter.replaceOpWithNewOp<InsertOpTy>(
2417  insertSliceOp, cast, insertSliceOp.getDest(),
2418  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2419  insertSliceOp.getMixedStrides());
2420  return success();
2421  }
2422 };
2423 } // namespace
2424 
2425 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2426  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2427 }
2428 
2429 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2430  MLIRContext *context) {
2431  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2432  InsertSliceOpCastFolder<InsertSliceOp>,
2433  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2434 }
2435 
2436 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2437  Location loc,
2438  Value tensor,
2439  Value dest) {
2440  auto rankedTensorType = dest.getType().cast<RankedTensorType>();
2441  unsigned rank = rankedTensorType.getRank();
2442  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2443  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2444  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2445  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2446  sizes, strides);
2447 }
2448 
2449 //===----------------------------------------------------------------------===//
2450 // PadOp
2451 //===----------------------------------------------------------------------===//
2452 
2453 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2454  setNameFn(getResult(), "padded");
2455 }
2456 
2457 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2458 // supports optional types.
2459 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2460  Type typeToInfer, Type typeToInferFrom) {}
2461 
2462 ParseResult
2463 parseInferType(OpAsmParser &parser,
2464  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2465  Type &typeToInfer, Type typeToInferFrom) {
2466  if (optOperand)
2467  typeToInfer = typeToInferFrom;
2468  return success();
2469 }
2470 
2471 LogicalResult PadOp::verify() {
2472  auto sourceType = getSource().getType().cast<RankedTensorType>();
2473  auto resultType = getResult().getType().cast<RankedTensorType>();
2474  auto expectedType =
2475  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2476  if (!expectedType) {
2477  return emitError("failed to infer expectedType from sourceType ")
2478  << sourceType << ", specified resultType is " << resultType;
2479  }
2480  if (resultType.getRank() != expectedType.getRank()) {
2481  return emitError("specified type ")
2482  << resultType << " does not match the inferred type "
2483  << expectedType;
2484  }
2485  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2486  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2487  continue;
2488  if (expectedType.isDynamicDim(i))
2489  continue;
2490  return emitError("specified type ")
2491  << resultType << " does not match the inferred type "
2492  << expectedType;
2493  }
2494 
2495  return success();
2496 }
2497 
2498 LogicalResult PadOp::verifyRegions() {
2499  auto &region = getRegion();
2500  unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
2501  Block &block = region.front();
2502  if (block.getNumArguments() != rank)
2503  return emitError("expected the block to have ") << rank << " arguments";
2504 
2505  // Note: the number and type of yield values are checked in the YieldOp.
2506  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2507  if (!en.value().isIndex())
2508  return emitOpError("expected block argument ")
2509  << (en.index() + 1) << " to be an index";
2510  }
2511 
2512  // Ensure that the region yields an element of the right type.
2513  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2514  if (yieldOp.getValue().getType() !=
2515  getType().cast<ShapedType>().getElementType())
2516  return emitOpError("expected yield type to match shape element type");
2517 
2518  return success();
2519 }
2520 
2521 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2522  ArrayRef<int64_t> staticLow,
2523  ArrayRef<int64_t> staticHigh,
2524  ArrayRef<int64_t> resultShape) {
2525  unsigned rank = sourceType.getRank();
2526  if (staticLow.size() != rank)
2527  return RankedTensorType();
2528  if (staticHigh.size() != rank)
2529  return RankedTensorType();
2530  if (!(resultShape.empty() || resultShape.size() == rank))
2531  return RankedTensorType();
2532 
2533  SmallVector<int64_t, 4> inferredShape;
2534  for (auto i : llvm::seq<unsigned>(0, rank)) {
2535  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2536  staticHigh[i] == ShapedType::kDynamic) {
2537  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2538  : resultShape[i]);
2539  } else {
2540  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2541  assert((resultShape.empty() || size == resultShape[i] ||
2542  resultShape[i] == ShapedType::kDynamic) &&
2543  "mismatch between inferred shape and result shape");
2544  inferredShape.push_back(size);
2545  }
2546  }
2547 
2548  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2549 }
2550 
2551 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2552  Value source, ArrayRef<int64_t> staticLow,
2553  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
2554  bool nofold, ArrayRef<NamedAttribute> attrs) {
2555  auto sourceType = source.getType().cast<RankedTensorType>();
2556  if (!resultType)
2557  resultType = inferResultType(sourceType, staticLow, staticHigh);
2558  build(b, result, resultType, source, low, high,
2559  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2560  nofold ? b.getUnitAttr() : UnitAttr());
2561  result.addAttributes(attrs);
2562 }
2563 
2564 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2565  Value source, ValueRange low, ValueRange high, bool nofold,
2566  ArrayRef<NamedAttribute> attrs) {
2567  auto sourceType = source.getType().cast<RankedTensorType>();
2568  unsigned rank = sourceType.getRank();
2569  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
2570  build(b, result, resultType, source, staticVector, staticVector, low, high,
2571  nofold, attrs);
2572 }
2573 
2574 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2575  Value source, ArrayRef<OpFoldResult> low,
2576  ArrayRef<OpFoldResult> high, bool nofold,
2577  ArrayRef<NamedAttribute> attrs) {
2578  auto sourceType = source.getType().cast<RankedTensorType>();
2579  SmallVector<Value, 4> dynamicLow, dynamicHigh;
2580  SmallVector<int64_t, 4> staticLow, staticHigh;
2581  // staticLow and staticHigh have full information of the padding config.
2582  // Each entry grows staticLow and staticHigh by one value. If the entry is
2583  // dynamic (i.e. not a constant), dynamicLow and dynamicHigh grow by one
2584  // value as well.
2585  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
2586  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
2587  if (!resultType) {
2588  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
2589  }
2590  assert(resultType.isa<RankedTensorType>());
2591  build(b, result, resultType, source, dynamicLow, dynamicHigh,
2592  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2593  nofold ? b.getUnitAttr() : UnitAttr());
2594  result.addAttributes(attrs);
2595 }
2596 
2597 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2598  Value source, ArrayRef<OpFoldResult> low,
2599  ArrayRef<OpFoldResult> high, Value constantPadValue,
2600  bool nofold, ArrayRef<NamedAttribute> attrs) {
2601  build(b, result, resultType, source, low, high, nofold, attrs);
2602 
2603  // Add a region and a block to yield the pad value.
2604  Region *region = result.regions[0].get();
2605  int sourceRank = source.getType().cast<RankedTensorType>().getRank();
2606  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
2607  SmallVector<Location> blockArgLocs(sourceRank, result.location);
2608 
2609  // `builder.createBlock` changes the insertion point within the block. Create
2610  // a guard to reset the insertion point of the builder after it is destroyed.
2611  OpBuilder::InsertionGuard guard(b);
2612  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
2613  b.create<tensor::YieldOp>(result.location, constantPadValue);
2614 }
2615 
2616 llvm::SmallBitVector PadOp::getPaddedDims() {
2617  llvm::SmallBitVector paddedDims(getSourceType().getRank());
2618  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
2619  for (const auto &en : enumerate(paddingWidths))
2620  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
2621  paddedDims.set(en.index());
2622  };
2623  extractPaddedDims(getMixedLowPad());
2624  extractPaddedDims(getMixedHighPad());
2625  return paddedDims;
2626 }
2627 
2628 namespace {
2629 // Folds tensor.pad into a tensor.cast when all padding is statically zero
2630 // and the nofold attribute is not set.
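//
// Illustrative sketch (hypothetical shapes, not from this file):
//
//   %p = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x8xf32> to tensor<4x8xf32>
//
// is replaced by a tensor.cast of %t to the result type.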
2631 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
2632  using OpRewritePattern<PadOp>::OpRewritePattern;
2633 
2634  LogicalResult matchAndRewrite(PadOp padTensorOp,
2635  PatternRewriter &rewriter) const override {
2636  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
2637  return failure();
2638  if (padTensorOp.getNofold())
2639  return failure();
2640  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2641  padTensorOp, padTensorOp.getResult().getType(),
2642  padTensorOp.getSource());
2643  return success();
2644  }
2645 };
2646 
2647 // Fold CastOp into PadOp when adding static information.
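//
// Illustrative sketch (hypothetical shapes, not from this file):
//
//   %c = tensor.cast %t : tensor<8x8xf32> to tensor<?x?xf32>
//   %p = tensor.pad %c low[0, 0] high[2, 2] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
//
// may be rewritten to pad %t directly, producing a tensor<10x10xf32> result
// followed by a tensor.cast back to the original tensor<?x?xf32> type.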
2648 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
2649  using OpRewritePattern<PadOp>::OpRewritePattern;
2650 
2651  LogicalResult matchAndRewrite(PadOp padTensorOp,
2652  PatternRewriter &rewriter) const override {
2653  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
2654  if (!tensor::canFoldIntoConsumerOp(castOp))
2655  return failure();
2656 
2657  auto newResultType = PadOp::inferResultType(
2658  castOp.getSource().getType().cast<RankedTensorType>(),
2659  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2660  padTensorOp.getResultType().getShape());
2661 
2662  if (newResultType == padTensorOp.getResultType()) {
2663  rewriter.updateRootInPlace(padTensorOp, [&]() {
2664  padTensorOp.getSourceMutable().assign(castOp.getSource());
2665  });
2666  } else {
2667  auto newOp = rewriter.create<PadOp>(
2668  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
2669  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2670  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
2671  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2672  IRMapping mapper;
2673  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
2674 
2675  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2676  padTensorOp, padTensorOp.getResultType(), newOp);
2677  }
2678  return success();
2679  }
2680 };
2681 
2682 // Fold CastOp using the result of PadOp back into the latter if it adds
2683 // static information.
2684 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2685  using OpRewritePattern<PadOp>::OpRewritePattern;
2686 
2687  LogicalResult matchAndRewrite(PadOp padTensorOp,
2688  PatternRewriter &rewriter) const override {
2689  if (!padTensorOp.getResult().hasOneUse())
2690  return failure();
2691  auto tensorCastOp =
2692  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2693  if (!tensorCastOp)
2694  return failure();
2695  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
2696  tensorCastOp.getDest().getType()))
2697  return failure();
2698 
2699  auto replacementOp = rewriter.create<PadOp>(
2700  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2701  padTensorOp.getSource(), padTensorOp.getStaticLow(),
2702  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
2703  padTensorOp.getHigh(), padTensorOp.getNofold(),
2704  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2705  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2706 
2707  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
2708  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
2709  return success();
2710  }
2711 };
2712 
2713 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2714 /// different dimensions. The pattern applies if the following preconditions
2715 /// hold:
2716 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
2717 /// 2) the tensor::ExtractSliceOps have only unit-strides,
2718 /// 3) the tensor::PadOps perform only high-padding,
2719 /// 4) the tensor::PadOps have the same constant padding value,
2720 /// 5) the tensor::PadOps do not have common padding dimensions,
2721 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2722 /// zero-offset for every dimension.
2723 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
2724 /// the padded source dimensions.
2726 ///
2727 /// Example:
2728 ///
2729 /// ```mlir
2730 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2731 /// : tensor<64x64xf32> to tensor<?x64xf32>
2732 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2733 /// } : tensor<?x64xf32> to tensor<8x64xf32>
2734 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2735 /// : tensor<8x64xf32> to tensor<8x?xf32>
2736 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2737 /// } : tensor<8x?xf32> to tensor<8x4xf32>
2738 /// ```
2739 ///
2740 /// folds into:
2741 ///
2742 /// ```mlir
2743 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2744 /// : tensor<64x64xf32> to tensor<?x?xf32>
2745 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2746 /// } : tensor<?x?xf32> to tensor<8x4xf32>
2747 /// ```
2748 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
2749  using OpRewritePattern<PadOp>::OpRewritePattern;
2750 
2751  LogicalResult matchAndRewrite(PadOp padOp,
2752  PatternRewriter &rewriter) const override {
2753  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
2754  if (!innerSliceOp)
2755  return failure();
2756  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
2757  if (!outerPadOp || outerPadOp.getNofold())
2758  return failure();
2759  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
2760  if (!outerSliceOp)
2761  return failure();
2762 
2763  // 1) Fail if the chain is rank-reducing.
2764  int64_t rank = padOp.getSourceType().getRank();
2765  if (outerSliceOp.getSourceType().getRank() != rank) {
2766  return rewriter.notifyMatchFailure(padOp,
2767  "cannot fold rank-reducing chain");
2768  }
2769 
2770  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2771  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
2772  return rewriter.notifyMatchFailure(
2773  padOp, "cannot fold non-unit stride ExtractSliceOps");
2774  }
2775 
2776  // 3) Fail if the tensor::PadOps have non-zero low padding.
2777  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
2778  return rewriter.notifyMatchFailure(padOp,
2779  "cannot fold PadOps with low padding");
2780  }
2781 
2782  // 4) Fail if the tensor::PadOps padding values do not match.
2783  Attribute innerAttr, outerAttr;
2784  Value innerValue = padOp.getConstantPaddingValue();
2785  Value outerValue = outerPadOp.getConstantPaddingValue();
2786  if (!innerValue || !outerValue ||
2787  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
2788  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
2789  innerAttr != outerAttr) {
2790  return rewriter.notifyMatchFailure(
2791  padOp, "cannot fold PadOps with different padding values");
2792  }
2793 
2794  // 5) Fail if a dimension is padded by both tensor::PadOps.
2795  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
2796  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
2797  if (innerDims.anyCommon(outerDims)) {
2798  return rewriter.notifyMatchFailure(
2799  padOp, "cannot fold PadOps with common padding dimensions");
2800  }
2801 
2802  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2803  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2804  // for every dimension, and use the offset of the other pair. Fail if no
2805  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2806  // exists.
2807  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
2808  for (auto en : enumerate(newOffsets)) {
2809  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
2810  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
2811  if (!innerDims.test(en.index()) &&
2812  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
2813  en.value() = outerOffset;
2814  continue;
2815  }
2816  if (!outerDims.test(en.index()) &&
2817  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
2818  en.value() = innerOffset;
2819  continue;
2820  }
2821  return rewriter.notifyMatchFailure(
2822  padOp, "cannot find zero-offset and zero-padding pair");
2823  }
2824 
2825  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
2826  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
2827  // outer tensor::PadOp and fail if the size of the inner
2828  // tensor::ExtractSliceOp does not match the size of the padded dimension.
2829  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
2830  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
2831  for (auto en : enumerate(newSizes)) {
2832  if (!outerDims.test(en.index()))
2833  continue;
2834  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
2835  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
2836  assert(!ShapedType::isDynamic(sourceSize) &&
2837  "expected padded dimension to have a static size");
2838  if (getConstantIntValue(sliceSize) != sourceSize) {
2839  return rewriter.notifyMatchFailure(
2840  padOp, "cannot fold since the inner ExtractSliceOp size does not "
2841  "match the size of the outer padding");
2842  }
2843  en.value() = outerSliceOp.getMixedSizes()[en.index()];
2844  }
2845 
2846  // Combine the high paddings of the two tensor::PadOps.
2847  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
2848  for (auto en : enumerate(newHighPad)) {
2849  if (innerDims.test(en.index()))
2850  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
2851  if (outerDims.test(en.index()))
2852  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
2853  }
2854 
2855  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
2856  // the two paddings in one step.
2857  auto newSliceOp = rewriter.create<ExtractSliceOp>(
2858  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
2859  innerSliceOp.getMixedStrides());
2860  auto newPadOp = rewriter.create<PadOp>(
2861  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
2862  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
2863  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
2864  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
2865  newPadOp.getRegion().begin());
2866  rewriter.replaceOp(padOp, newPadOp.getResult());
2867  return success();
2868  }
2869 };
2870 
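// Folds constant low/high padding operands of tensor.pad into the static
// padding attributes and computes a more static result type where possible.
// Illustrative sketch (hypothetical shapes, not from this file):
//
//   %c2 = arith.constant 2 : index
//   %p = tensor.pad %t low[0, 0] high[%c2, 0] { ... }
//       : tensor<4x8xf32> to tensor<?x8xf32>
//
// may be rewritten into a pad with static high[2, 0] producing a
// tensor<6x8xf32>, followed by a tensor.cast back to tensor<?x8xf32>.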
2871 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
2872  using OpRewritePattern<PadOp>::OpRewritePattern;
2873 
2874  LogicalResult matchAndRewrite(PadOp padTensorOp,
2875  PatternRewriter &rewriter) const override {
2876  Value input = padTensorOp.getSource();
2877  if (!input.getType().isa<RankedTensorType>())
2878  return failure();
2879  auto inputDims = input.getType().cast<RankedTensorType>().getShape();
2880  auto inputRank = inputDims.size();
2881 
2882  auto oldResultType =
2883  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
2884  if (!oldResultType)
2885  return failure();
2886 
2887  auto outputDims = oldResultType.getShape();
2888 
2889  // Extract the static info from the high and low operands.
2890  SmallVector<int64_t> constOperandsLow;
2891  for (auto operand : padTensorOp.getLow()) {
2892  APSInt intOp;
2893  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
2894  constOperandsLow.push_back(ShapedType::kDynamic);
2895  continue;
2896  }
2897  constOperandsLow.push_back(intOp.getExtValue());
2898  }
2899  SmallVector<int64_t> constOperandsHigh;
2900  for (auto operand : padTensorOp.getHigh()) {
2901  APSInt intOp;
2902  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
2903  constOperandsHigh.push_back(ShapedType::kDynamic);
2904  continue;
2905  }
2906  constOperandsHigh.push_back(intOp.getExtValue());
2907  }
2908 
2909  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
2910  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
2911 
2912  // Verify the op is well-formed.
2913  if (inputDims.size() != outputDims.size() ||
2914  inputDims.size() != constLow.size() ||
2915  inputDims.size() != constHigh.size())
2916  return failure();
2917 
2918  auto lowCount = 0;
2919  auto highCount = 0;
2920  for (size_t i = 0; i < inputRank; i++) {
2921  if (constLow[i] == ShapedType::kDynamic)
2922  constLow[i] = constOperandsLow[lowCount++];
2923  if (constHigh[i] == ShapedType::kDynamic)
2924  constHigh[i] = constOperandsHigh[highCount++];
2925  }
2926 
2927  auto staticLow = ArrayRef<int64_t>(constLow);
2928  auto staticHigh = ArrayRef<int64_t>(constHigh);
2929 
2930  // Calculate the output sizes with the static information.
2931  SmallVector<int64_t> newOutDims;
2932  for (size_t i = 0; i < inputRank; i++) {
2933  if (outputDims[i] == ShapedType::kDynamic) {
2934  newOutDims.push_back(
2935  (staticLow[i] == ShapedType::kDynamic ||
2936  staticHigh[i] == ShapedType::kDynamic ||
2937  inputDims[i] == ShapedType::kDynamic
2938  ? ShapedType::kDynamic
2939  : inputDims[i] + staticLow[i] + staticHigh[i]));
2940  } else {
2941  newOutDims.push_back(outputDims[i]);
2942  }
2943  }
2944 
2945  if (SmallVector<int64_t>(outputDims) == newOutDims ||
2946  llvm::all_of(newOutDims,
2947  [&](int64_t x) { return x == ShapedType::kDynamic; }))
2948  return failure();
2949 
2950  // Rewrite the op using the new static type.
2951  auto newResultType = RankedTensorType::get(
2952  newOutDims, padTensorOp.getType().getElementType());
2953  auto newOp = rewriter.create<PadOp>(
2954  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
2955  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
2956  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2957 
2958  IRMapping mapper;
2959  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
2960  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
2961  newOp);
2962 
2963  return success();
2964  }
2965 };
2966 
2967 } // namespace
2968 
2969 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2970  MLIRContext *context) {
2971  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2972  FoldOrthogonalPaddings, FoldStaticPadding>(context);
2973 }
2974 
2975 /// Return the padding value of the PadOp if it is constant. In this context,
2976 /// "constant" means an actual constant or "defined outside of the block".
2977 ///
2978 /// Values are considered constant in three cases:
2979 /// - A ConstantLike value.
2980 /// - A basic block argument from a different block.
2981 /// - A value defined outside of the block.
2982 ///
2983 /// If the padding value is not constant, an empty Value is returned.
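///
/// Illustrative example (not part of the original source): in
///
/// ```mlir
///   %cst = arith.constant 0.0 : f32
///   %0 = tensor.pad %arg0 low[1] high[2] {
///   ^bb0(%i: index):
///     tensor.yield %cst : f32
///   } : tensor<4xf32> to tensor<7xf32>
/// ```
///
/// the returned padding value is %cst, because it is defined outside of the
/// pad body block.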
2984 Value PadOp::getConstantPaddingValue() {
2985  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
2986  if (!yieldOp)
2987  return {};
2988  Value padValue = yieldOp.getValue();
2989  // Check if yield value is a constant.
2990  if (matchPattern(padValue, m_Constant()))
2991  return padValue;
2992  // Check if yield value is defined inside the PadOp block.
2993  if (padValue.getParentBlock() == &getRegion().front())
2994  return {};
2995  // Else: Yield value defined outside of the PadOp block.
2996  return padValue;
2997 }
2998 
2999 OpFoldResult PadOp::fold(FoldAdaptor) {
3000  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3001  !getNofold())
3002  return getSource();
3003  return {};
3004 }
3005 
3006 //===----------------------------------------------------------------------===//
3007 // ParallelInsertSliceOp
3008 //===----------------------------------------------------------------------===//
3009 
3010 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3011  ParallelCombiningOpInterface parallelCombiningParent =
3012  getParallelCombiningParent();
3013  for (const auto &it :
3014  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3015  Operation &nextOp = it.value();
3016  if (&nextOp == getOperation())
3017  return parallelCombiningParent.getParentResult(it.index());
3018  }
3019  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3020 }
3021 
3022 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3023 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3024  Value source, Value dest,
3025  ArrayRef<OpFoldResult> offsets,
3026  ArrayRef<OpFoldResult> sizes,
3027  ArrayRef<OpFoldResult> strides,
3028  ArrayRef<NamedAttribute> attrs) {
3029  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3030  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3031  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3032  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3033  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3034  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3035  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3036  b.getDenseI64ArrayAttr(staticSizes),
3037  b.getDenseI64ArrayAttr(staticStrides));
3038  result.addAttributes(attrs);
3039 }
3040 
3041 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3042 /// packed into a Range vector.
3043 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3044  Value source, Value dest,
3045  ArrayRef<Range> ranges,
3046  ArrayRef<NamedAttribute> attrs) {
3047  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3048  build(b, result, source, dest, offsets, sizes, strides, attrs);
3049 }
3050 
3051 // Build a ParallelInsertSliceOp with dynamic entries.
3052 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3053  Value source, Value dest, ValueRange offsets,
3054  ValueRange sizes, ValueRange strides,
3055  ArrayRef<NamedAttribute> attrs) {
3056  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3057  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3058  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3059  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3060  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3061  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3062  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3063 }
3064 
3065 LogicalResult ParallelInsertSliceOp::verify() {
3066  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3067  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3068  << *(getOperation()->getParentOp());
3069 
3070  ShapedType expectedType;
3071  SliceVerificationResult result =
3072  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3073  getStaticSizes(), getStaticStrides(), &expectedType);
3074  return produceSliceErrorMsg(result, *this, expectedType);
3075 }
3076 
3077 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3078  RewritePatternSet &results, MLIRContext *context) {
3079  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3080  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3081  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3082 }
3083 
3084 //===----------------------------------------------------------------------===//
3085 // ScatterOp
3086 //===----------------------------------------------------------------------===//
3087 
3088 void ScatterOp::getAsmResultNames(
3089  function_ref<void(Value, StringRef)> setNameFn) {
3090  setNameFn(getResult(), "scatter");
3091 }
3092 
3093 LogicalResult ScatterOp::verify() {
3094  int64_t destRank = getDestType().getRank();
3095  ArrayRef<int64_t> scatterDims = getScatterDims();
3096  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3097  "scatter", "dest")))
3098  return failure();
3099 
3100  if (!getUnique())
3101  return emitOpError("requires 'unique' attribute to be set");
3102  // TODO: we could also check statically that there are fewer leading index
3103  // tensor dims than the dest dims. If this is not the case, the unique
3104  // attribute cannot be true.
3105 
3106  // Use the GatherOp::inferResultType on the `dest` type and verify the
3107  // expected type matches the source type.
3108  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3109  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3110  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3111  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3112  if (getSourceType() != expectedSourceType &&
3113  getSourceType() != expectedRankReducedSourceType) {
3114  return emitOpError("source type "
3115  "mismatch: "
3116  "expected ")
3117  << expectedSourceType << " or its rank-reduced variant "
3118  << expectedRankReducedSourceType << " (got: " << getSourceType()
3119  << ")";
3120  }
3121 
3122  return success();
3123 }
3124 
3125 //===----------------------------------------------------------------------===//
3126 // SplatOp
3127 //===----------------------------------------------------------------------===//
3128 
3129 void SplatOp::getAsmResultNames(
3130  function_ref<void(Value, StringRef)> setNameFn) {
3131  setNameFn(getResult(), "splat");
3132 }
3133 
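// Illustrative example for the fold below (not part of the original source):
// folding a tensor.splat of the constant 1.0 : f32 with result type
// tensor<4xf32> yields the splat attribute dense<1.000000e+00> : tensor<4xf32>.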
3134 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3135  auto constOperand = adaptor.getInput();
3136  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
3137  return {};
3138 
3139  // SplatElementsAttr::get treats a single value for the second arg as a
3140  // splat.
3141  return SplatElementsAttr::get(getType(), {constOperand});
3142 }
3143 
3144 //===----------------------------------------------------------------------===//
3145 // PackOp/UnPackOp Common
3146 //===----------------------------------------------------------------------===//
3147 
3148 namespace {
3149 
3150 /// Packing a one-dimensional tensor can be expressed as an expand_shape op.
3151 struct SimplifyPackToExandShape : public OpRewritePattern<PackOp> {
3152   using OpRewritePattern<PackOp>::OpRewritePattern;
3153 
3154  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
3155  Type newOperandType, ArrayAttr reassociation) const {
3156  if (operand.getType() == newOperandType)
3157  return operand;
3158  return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
3159  reassociation);
3160  }
3161 
3162  LogicalResult matchAndRewrite(PackOp packOp,
3163  PatternRewriter &rewriter) const override {
3164  RankedTensorType sourceType = packOp.getSourceType();
3165  RankedTensorType destType = packOp.getDestType();
3166  if (sourceType.getRank() != 1 || packOp.getPaddingValue())
3167  return failure();
3168  auto reassociation =
3169  getReassociationIndicesForReshape(sourceType, destType);
3170  if (!reassociation)
3171  return failure();
3172  Value expanded = insertExpand(
3173  rewriter, packOp.getLoc(), packOp.getSource(), destType,
3174  getReassociationIndicesAttribute(rewriter, *reassociation));
3175  rewriter.replaceOp(packOp, expanded);
3176  return success();
3177  }
3178 };
3179 
3180 } // namespace
3181 
3182 void mlir::tensor::populateSimplifyTensorPack(RewritePatternSet &patterns) {
3183  patterns.add<SimplifyPackToExandShape>(patterns.getContext());
3184 }
3185 
3186 template <typename OpTy>
3187 static LogicalResult
3188 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3189  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3190  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3191  "applies to only pack or unpack operations");
3192  int64_t destRank = op.getDestRank();
3193  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3194  ShapedType resultType = op.getResult().getType().template cast<ShapedType>();
3195  for (auto dim : llvm::seq<int64_t>(0, destRank)) {
3196  if (resultType.isDynamicDim(dim)) {
3197  reifiedReturnShapes[0][dim] =
3198  builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
3199  } else {
3200  reifiedReturnShapes[0][dim] =
3201  builder.getIndexAttr(resultType.getDimSize(dim));
3202  }
3203  }
3204  return success();
3205 }
3206 
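// Illustrative example for the helper below (not part of the original source):
// for an op with inner_dims_pos = [0, 2] and inner_tiles = [16, 32], the
// returned mapping is {0 -> 16, 2 -> 32}.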
3207 template <typename OpTy>
3208 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3209   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3210  "applies to only pack or unpack operations");
3211  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3212  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3213  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3214  assert(tiles.size() == dimsToTile.size() &&
3215  "tiles must match indices of dimension to block");
3216  // Bind each dimension in `dimsToTile` to its tile factor.
3217  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3218  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3219  return dimAndTileMapping;
3220 }
3221 
3222 template <typename OpTy>
3223 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3224   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3225  "applies to only pack or unpack operations");
3226  Builder builder(op);
3227  SmallVector<OpFoldResult> mixedInnerTiles;
3228  unsigned dynamicValIndex = 0;
3229  for (int64_t staticTile : op.getStaticInnerTiles()) {
3230  if (!ShapedType::isDynamic(staticTile))
3231  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3232  else
3233  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3234  }
3235  return mixedInnerTiles;
3236 }
3237 
3238 template <typename OpTy>
3239 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3240   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3241  "applies to only pack or unpack operations");
3242  SmallVector<Value> dynamicTiles;
3243  SmallVector<int64_t> staticTiles;
3244  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3245  return staticTiles;
3246 }
3247 
3248 /// Returns true if `dimsPos` is invalid. It is invalid when:
3249 /// a) It contains duplicates.
3250 /// b) At least one dimension is out of bounds (a valid `dimPos` satisfies 0 <= `dimPos` < `rank`).
3251 /// c) The number of elements in `dimsPos` is greater than `rank`.
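///
/// Illustrative example (not part of the original source): with rank = 3,
/// dimsPos = [0, 2] is valid, while [0, 0] (duplicate) and [1, 3] (out of
/// bounds) are both invalid.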
3252 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3253  size_t rank) {
3254  size_t dimsPosSize = dimsPos.size();
3255  if (dimsPosSize > rank)
3256  return true;
3257  DenseSet<int64_t> uniqued;
3258  for (int64_t dim : dimsPos)
3259  uniqued.insert(dim);
3260  if (dimsPosSize != uniqued.size())
3261  return true;
3262  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3263  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3264  });
3265 }
3266 
3267 /// Returns true if each dimension of `sourceShape` is statically smaller than
3268 /// or equal to the corresponding dimension of `limitShape` (dynamic dimensions are always considered in bound).
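///
/// Illustrative example (not part of the original source):
/// areAllInBound({4, ShapedType::kDynamic}, {4, 8}) is true, whereas
/// areAllInBound({9, 8}, {8, 8}) is false.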
3269 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3270  ArrayRef<int64_t> limitShape) {
3271  assert(
3272  sourceShape.size() == limitShape.size() &&
3273  "expected source shape rank, and limit of the shape to have same rank");
3274  return llvm::all_of(
3275  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3276  int64_t sourceExtent = std::get<0>(it);
3277  int64_t limit = std::get<1>(it);
3278  return ShapedType::isDynamic(sourceExtent) ||
3279  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3280  });
3281 }
3282 
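// Illustrative example of the shape relation verified below (not part of the
// original source): packing a tensor<128x256xf32> with inner_dims_pos = [0, 1]
// and inner_tiles = [8, 32] yields a tensor<16x8x8x32xf32>: the packed rank (4)
// equals the unpacked rank (2) plus the number of tiling factors (2), and the
// trailing dimensions equal the tile sizes.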
3283 template <typename OpTy>
3284 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3285   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3286  "applies to only pack or unpack operations");
3287  Operation *op = packOrUnPack.getOperation();
3288 
3289  // Return true if we have a zero-value tile.
3290  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3291  return llvm::any_of(
3292  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3293  };
3294 
3295  // Verify tiles. Do not allow zero tiles.
3296  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3297  if (hasZeros(mixedTiles))
3298  return op->emitError("invalid zero tile factor");
3299 
3300  // Verify inner_dims_pos and outer_dims_perm.
3301  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
3302  ? packOrUnPack.getSourceType()
3303  : packOrUnPack.getDestType();
3304  size_t unpackedRank = unpackedType.getRank();
3305  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3306  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3307  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3308  return op->emitError("invalid inner_dims_pos vector");
3309  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3310  return op->emitError("invalid outer_dims_perm vector");
3311  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3312  return op->emitError("outer_dims_perm must be a permutation or empty");
3313 
3314  // Tiling factors must be less than or equal to the input rank for pack (or
3315  // output rank for unpack), and must match the number of `inner_dims_pos`.
3316  if (mixedTiles.size() > unpackedRank) {
3317  return op->emitError("tiling factors must be less than or equal to the "
3318  "input rank for pack or output rank for unpack");
3319  }
3320  if (mixedTiles.size() != innerDimsPos.size()) {
3321  return op->emitError(
3322  "tiling factors must equal the number of dimensions to tile");
3323  }
3324 
3325  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3326  ? packOrUnPack.getDestType()
3327  : packOrUnPack.getSourceType();
3328  size_t packedRank = packedType.getRank();
3329  // Require output rank to match input rank + number of blocking factors.
3330  if (unpackedRank + mixedTiles.size() != packedRank) {
3331  return op->emitError(
3332  "packed rank must equal unpacked rank + tiling factors");
3333  }
3334 
3335  // Verify that the result shape is no smaller than the minimum expected
3336  // by the pack operation, and that the output shape
3337  // represents full tiles.
3338  ShapedType expectedPackedType = PackOp::inferPackedType(
3339  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3340  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3341  return op->emitError("the shape of output is not large enough to hold the "
3342  "packed data. Expected at least ")
3343  << expectedPackedType << ", got " << packedType;
3344  }
3345  if (!llvm::all_of(
3346  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3347  mixedTiles),
3348  [](std::tuple<int64_t, OpFoldResult> it) {
3349  std::optional<int64_t> constTileSize =
3350  getConstantIntValue(std::get<1>(it));
3351  int64_t shape = std::get<0>(it);
3352  if (!constTileSize) {
3353  // If specified tile size is dynamic, output shape should
3354  // be dynamic too.
3355  return ShapedType::isDynamic(shape);
3356  }
3357  if (ShapedType::isDynamic(shape)) {
3358  // If the shape is dynamic while the tile size is
3359  // constant, return true. In canonical form a constant
3360  // tile size should lead to a constant shape for the tiled
3361  // dimension, but that is not required for verification.
3362  return true;
3363  }
3364  return shape == constTileSize.value();
3365  })) {
3366  return op->emitError("mismatch in inner tile sizes specified and shaped of "
3367  "tiled dimension in the packed type");
3368  }
3369  return success();
3370 }
3371 
3372 namespace {
3373 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
3374 /// various permutations to the op.
3375 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
3376 // these. These may or may not become true foldings / canonicalizations
3377 // depending on how aggressive we want to be in automatically folding
3378 // transposes.
3379 struct PackOrUnPackTransposeResult {
3380  SmallVector<int64_t> innerDimsPos;
3381  SmallVector<OpFoldResult> innerTiles;
3382  SmallVector<int64_t> outerDimsPerm;
3383 };
3384 } // namespace
3385 
3386 template <typename OpTy>
3387 static PackOrUnPackTransposeResult
3388 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
3389  ArrayRef<int64_t> innerPermutation,
3390  ArrayRef<int64_t> outerPermutation) {
3391  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3392  "applies to only pack or unpack operations");
3393  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3394  "some permutation must be non-empty");
3395  PackOrUnPackTransposeResult metadata;
3396  metadata.innerDimsPos =
3397  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
3398  metadata.innerTiles =
3399  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
3400  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3401  ? packOrUnPackOp.getSourceRank()
3402  : packOrUnPackOp.getDestRank();
3403  metadata.outerDimsPerm =
3404  packOrUnPackOp.getOuterDimsPerm().empty()
3405  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3406  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
3407  if (!innerPermutation.empty()) {
3408  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3409  isPermutationVector(innerPermutation) &&
3410  "invalid inner permutation");
3411  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
3412  applyPermutationToVector(metadata.innerTiles, innerPermutation);
3413  }
3414  if (!outerPermutation.empty()) {
3415  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3416  isPermutationVector(outerPermutation) &&
3417  "invalid outer permutation");
3418  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
3419  }
3420  return metadata;
3421 }
3422 
3423 //===----------------------------------------------------------------------===//
3424 // PackOp
3425 //===----------------------------------------------------------------------===//
3426 
3427 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3428  setNameFn(getResult(), "pack");
3429 }
3430 
3431 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3432  Value dest, ArrayRef<int64_t> innerDimsPos,
3433  ArrayRef<OpFoldResult> innerTiles,
3434  std::optional<Value> paddingValue,
3435  ArrayRef<int64_t> outerDimsPerm) {
3436  assert(innerDimsPos.size() == innerTiles.size() &&
3437  "number of tile sizes specified must match the specified number of "
3438  "original dimensions to be tiled");
3439  SmallVector<int64_t> staticTileSizes;
3440  SmallVector<Value> dynamicTileSizes;
3441  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3442  build(builder, state, dest.getType(), source, dest,
3443  paddingValue ? *paddingValue : nullptr,
3444  outerDimsPerm.empty() ? nullptr
3445  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3446  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3447  builder.getDenseI64ArrayAttr(staticTileSizes));
3448 }
3449 
3450 LogicalResult
3451 PackOp::reifyResultShapes(OpBuilder &builder,
3452  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3453  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3454 }
3455 
3456 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3457  return getDimAndTileMappingImpl(*this);
3458 }
3459 
3460 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3461  return getMixedTilesImpl(*this);
3462 }
3463 
3464 SmallVector<int64_t> PackOp::getStaticTiles() {
3465  return getStaticTilesImpl(*this);
3466 }
3467 
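// Illustrative example for the check below (not part of the original source):
// an input dimension of size 9 tiled by 4 requires a padding value
// (9 % 4 != 0), whereas a dimension of size 8 tiled by 4 does not.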
3468 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
3469  ArrayRef<int64_t> innerDimsPos,
3470  ArrayRef<OpFoldResult> innerTiles) {
3471  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3472  if (ShapedType::isDynamic(inputShape[pos]))
3473  continue;
3474  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3475  if (!constantTile)
3476  continue;
3477  if (inputShape[pos] % (*constantTile) != 0)
3478  return true;
3479  }
3480  return false;
3481 }
3482 
3483 LogicalResult PackOp::verify() {
3484   if (failed(commonVerifierPackAndUnPackOp(*this)))
3485  return failure();
3486 
3487  // Verify padding value, and bail out if the tile does not divide the
3488  // dimension fully. In the case of dynamic tile factors or dimensions, having
3489  // a partial tile is undefined behavior.
3490  auto paddingValue = getPaddingValue();
3491  if (paddingValue &&
3492  paddingValue.getType() != getSourceType().getElementType()) {
3493  return emitOpError("expected padding_value has ")
3494  << getSourceType().getElementType()
3495  << " but got: " << paddingValue.getType();
3496  }
3497 
3498  if (!paddingValue &&
3499  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3500  getMixedTiles())) {
3501  return emitOpError("invalid tile factor provided. Only full tiles are "
3502  "supported when padding_value is not set");
3503  }
3504  return success();
3505 }
3506 
3507 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
3508 /// Values to kDynamic, even if they are arith.constant values.
3509 static SmallVector<int64_t>
3510 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
3511   SmallVector<int64_t> result;
3512  for (auto o : ofrs) {
3513  // Have to do this first, as getConstantIntValue special-cases constants.
3514  if (o.dyn_cast<Value>())
3515  result.push_back(ShapedType::kDynamic);
3516  else
3517  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
3518  }
3519  return result;
3520 }
3521 
3522 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
3523 /// the packed type. Having a shared helper helps implement these two methods in
3524 /// a way that ensures that they agree on which dimensions are dynamic.
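///
/// Illustrative example (not part of the original source): for
/// sourceShape = [17, 128], innerTileSizes = [16], innerDimsPos = [1], and an
/// empty outerDimsPerm, the packed shape is [17, 8, 16]: dimension 1 is tiled
/// into ceilDiv(128, 16) = 8 outer iterations and the tile size 16 is
/// appended as a trailing dimension.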
3525 static SmallVector<int64_t> getPackOpResultTypeShape(
3526  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
3527  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
3528  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
3529  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3530  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3531  continue;
3532  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3533  resultShape[tiledDim.value()] = ShapedType::kDynamic;
3534  continue;
3535  }
3536  resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
3537  innerTileSizes[tiledDim.index()]);
3538  }
3539 
3540  // Swap tile loops if outer_dims_perm is available.
3541  if (!outerDimsPerm.empty())
3542  applyPermutationToVector(resultShape, outerDimsPerm);
3543 
3544  // Append the inner tile dimensions.
3545  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
3546  return resultShape;
3547 }
3548 
3549 SmallVector<OpFoldResult> PackOp::getResultShape(
3550  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
3551  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
3552  ArrayRef<int64_t> outerDimsPerm) {
3553  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
3554 
3555  AffineExpr s0, s1;
3556  bindSymbols(builder.getContext(), s0, s1);
3557  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
3558  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3559  resultDims[tiledDim.value()] = makeComposedFoldedAffineApply(
3560  builder, loc, ceilDivExpr,
3561  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
3562  }
3563  if (!outerDimsPerm.empty())
3564  applyPermutationToVector(resultDims, outerDimsPerm);
3565  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
3566 
3567  SmallVector<int64_t> resultTypeShape =
3568  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
3569  asShapeWithAnyValueAsDynamic(innerTileSizes),
3570  innerDimsPos, outerDimsPerm);
3571 
3572  // Fix up `resultDims` to ensure that they are Values if and only if the
3573  // result type shape says it's a dynamic dim. This is needed as callers may
3574  // use dispatchIndexOpFoldResults on the result, and rely on the exact number
3575  // of dynamic dims returned by that.
3576  for (unsigned i = 0; i < resultDims.size(); ++i) {
3577  if (!ShapedType::isDynamic(resultTypeShape[i]))
3578  continue;
3579  resultDims[i] =
3580  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
3581  }
3582 
3583  return resultDims;
3584 }
3585 
3586 /// Get the expected packed type based on source type, tile factors, position of
3587 /// the inner tiles and permutation of the outer tiled loop.
3588 ShapedType PackOp::inferPackedType(ShapedType sourceType,
3589  ArrayRef<int64_t> innerTileSizes,
3590  ArrayRef<int64_t> innerDimsPos,
3591  ArrayRef<int64_t> outerDimsPerm) {
3592   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
3593  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
3594  return RankedTensorType::get(resultShape, sourceType.getElementType());
3595 }
3596 
3597 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
3598  ArrayRef<OpFoldResult> innerTileSizes,
3599  ArrayRef<int64_t> innerDimsPos,
3600  ArrayRef<int64_t> outerDimsPerm) {
3601  AffineExpr dim0, dim1;
3602  bindDims(b.getContext(), dim0, dim1);
3603  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
3604  return makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1), {v1, v2});
3605  };
3606 
3607  SmallVector<OpFoldResult> mixedSizes;
3608  for (auto [index, value] :
3609  llvm::enumerate(source.getType().cast<RankedTensorType>().getShape())) {
3610  if (ShapedType::isDynamic(value))
3611  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
3612  else
3613  mixedSizes.push_back(b.getIndexAttr(value));
3614  }
3615  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
3616  int64_t dimPos = std::get<0>(it);
3617  OpFoldResult tileSize = std::get<1>(it);
3618  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
3619  }
3620  if (!outerDimsPerm.empty())
3621  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
3622 
3623  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
3624  auto elemType = source.getType().cast<ShapedType>().getElementType();
3625  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
3626 }
3627 
3628 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
3629  ArrayRef<int64_t> innerPermutation,
3630  ArrayRef<int64_t> outerPermutation) {
3631  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
3632  *this, innerPermutation, outerPermutation);
3633  Value transposedDest =
3634  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
3635  metadata.innerDimsPos, metadata.outerDimsPerm);
3636  return b.create<PackOp>(loc, getSource(), transposedDest,
3637  metadata.innerDimsPos, metadata.innerTiles,
3638  getPaddingValue(), metadata.outerDimsPerm);
3639 }
3640 
3641 /// Returns true if the tiles and the tiled dims are constant.
3642 template <typename OpTy>
3643 bool areTilesAndTiledDimsAllConstant(OpTy op) {
3644  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3645  "applies to only pack or unpack operations");
3646  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3647  ? op.getDestType()
3648  : op.getSourceType();
3649  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
3650  for (auto [dimDest, tile] : llvm::zip(
3651  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
3652  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
3653  if (!constTileSize || ShapedType::isDynamic(dimDest))
3654  return false;
3655  }
3656  return true;
3657 }
3658 
3659 Speculation::Speculatability PackOp::getSpeculatability() {
3660  if (auto paddingValue = getPaddingValue())
3661  return Speculation::NotSpeculatable;
3662 
3663  // The verifier already rejects operations where we can statically prove that
3664  // the tile sizes do not divide the dimension evenly; thus, it suffices to check
3665  // that the tiles and the tiled inner dimensions are constant.
3666  if (!areTilesAndTiledDimsAllConstant(*this))
3667  return Speculation::NotSpeculatable;
3668 
3669  return Speculation::Speculatable;
3670 }
3671 
3672 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
3673 // dimensions for pack and unpack.
3674 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
3675  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3676  return false;
3677  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
3678 }
3679 
3680 // Return true if pack and unpack have the same tiles.
3681 // Same SSA values or same integer constants.
3682 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
3683  auto packTiles = packOp.getMixedTiles();
3684  auto unPackTiles = unPackOp.getMixedTiles();
3685  if (packTiles.size() != unPackTiles.size())
3686  return false;
3687  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
3688  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
3689  return false;
3690  }
3691  return true;
3692 }
3693 
3694 /// Fold pack(unpack(x)) to x.
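///
/// Illustrative example (not part of the original source):
///
/// ```mlir
///   %1 = tensor.unpack %0 inner_dims_pos = [0] inner_tiles = [8]
///       into %dest0 : tensor<2x8xf32> -> tensor<16xf32>
///   %2 = tensor.pack %1 inner_dims_pos = [0] inner_tiles = [8]
///       into %dest1 : tensor<16xf32> -> tensor<2x8xf32>
/// ```
///
/// Here %2 folds to %0, since the unpack source type matches the pack
/// destination type and both ops use the same tiling.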
3695 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
3696  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
3697  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
3698  return failure();
3699  if (packOp.getPaddingValue() ||
3700  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
3701  !haveSameTiles(packOp, unPackOp))
3702  return failure();
3703  rewriter.replaceOp(packOp, unPackOp.getSource());
3704  return success();
3705 }
3706 
3707 //===----------------------------------------------------------------------===//
3708 // UnPackOp
3709 //===----------------------------------------------------------------------===//
3710 
3711 void UnPackOp::getAsmResultNames(
3712  function_ref<void(Value, StringRef)> setNameFn) {
3713  setNameFn(getResult(), "unpack");
3714 }
3715 
3716 LogicalResult
3717 UnPackOp::reifyResultShapes(OpBuilder &builder,
3718  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3719  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3720 }
3721 
3722 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
3723  return getDimAndTileMappingImpl(*this);
3724 }
3725 
3726 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
3727  return getMixedTilesImpl(*this);
3728 }
3729 
3730 SmallVector<int64_t> UnPackOp::getStaticTiles() {
3731  return getStaticTilesImpl(*this);
3732 }
3733 
3734 LogicalResult UnPackOp::verify() {
3735  return commonVerifierPackAndUnPackOp(*this);
3736 }
3737 
3738 Speculation::Speculatability UnPackOp::getSpeculatability() {
3739  // See PackOp::getSpeculatability.
3740  if (!areTilesAndTiledDimsAllConstant(*this))
3741  return Speculation::NotSpeculatable;
3742 
3743  return Speculation::Speculatable;
3744 }
3745 
3746 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
3747  Value dest, ArrayRef<int64_t> innerDimsPos,
3748  ArrayRef<OpFoldResult> innerTiles,
3749  ArrayRef<int64_t> outerDimsPerm) {
3750  assert(innerDimsPos.size() == innerTiles.size() &&
3751  "number of tile sizes specified must match the specified number of "
3752  "original dimensions to be tiled");
3753  SmallVector<int64_t> staticTileSizes;
3754  SmallVector<Value> dynamicTileSizes;
3755  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3756  build(builder, state, dest.getType(), source, dest,
3757  outerDimsPerm.empty() ? nullptr
3758  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3759  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3760  builder.getDenseI64ArrayAttr(staticTileSizes));
3761 }
3762 
3763 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
3764  Value transposedSource,
3765  ArrayRef<int64_t> innerPermutation,
3766  ArrayRef<int64_t> outerPermutation) {
3767  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
3768  *this, innerPermutation, outerPermutation);
3769  return b.create<UnPackOp>(loc, transposedSource, getDest(),
3770  metadata.innerDimsPos, metadata.innerTiles,
3771  metadata.outerDimsPerm);
3772 }
3773 
3774 /// Fold unpack(pack(x)) to x.
3775 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
3776  PatternRewriter &rewriter) {
3777  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
3778  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
3779  return failure();
3780  if (packOp.getPaddingValue() ||
3781  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
3782  !haveSameTiles(packOp, unPackOp))
3783  return failure();
3784  rewriter.replaceOp(unPackOp, packOp.getSource());
3785  return success();
3786 }
3787 
3788 //===----------------------------------------------------------------------===//
3789 // Common Canonicalizers and Folders.
3790 //===----------------------------------------------------------------------===//
3791 
3792 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
3793 /// the `tensor.cast` has a source that is more static than the consuming op.
3794 ///
3795 /// Example:
3796 /// ```mlir
3797 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3798 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
3799 /// ```
3800 ///
3801 /// folds into:
3802 ///
3803 /// ```mlir
3804 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
3805 /// ```
3806 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
3807 /// can add the pattern to their canonicalizers.
3808 struct FoldTensorCastProducerOp
3809  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
3810   using OpInterfaceRewritePattern<
3811  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
3812 
3813  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
3814  PatternRewriter &rewriter) const override {
3815  // InsertSliceOp has its own logic about folding tensor.cast ops.
3816  if (isa<InsertSliceOp>(op.getOperation()))
3817  return failure();
3818 
3819  // If no operand comes from a tensor::CastOp and can be folded then fail.
3820  bool hasTensorCastOperand =
3821  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
3822  if (opOperand.get().isa<BlockArgument>())
3823  return false;
3824  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
3825  return castOp && canFoldIntoConsumerOp(castOp);
3826  });
3827  if (!hasTensorCastOperand)
3828  return failure();
3829 
3830  SmallVector<Type, 4> newResultTypes;
3831  newResultTypes.reserve(op->getNumResults());
3832  SmallVector<Value, 4> newOperands;
3833  newOperands.reserve(op->getNumOperands());
3834  for (OpOperand &opOperand : op->getOpOperands()) {
3835  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
3836  bool fold = canFoldIntoConsumerOp(tensorCastOp);
3837  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
3838  if (op.isDpsInit(&opOperand) &&
3839  !newOperands.back().getType().isa<MemRefType>())
3840  newResultTypes.push_back(newOperands.back().getType());
3841  }
3842 
3843  // Clone op.
3844  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
3845  SmallVector<Value, 4> replacements;
3846  replacements.reserve(newOp->getNumResults());
3847  for (auto [oldResult, newResult] :
3848  llvm::zip(op->getResults(), newOp->getResults())) {
3849  if (newResult.getType() != oldResult.getType()) {
3850  replacements.push_back(rewriter.create<tensor::CastOp>(
3851  op->getLoc(), oldResult.getType(), newResult));
3852  } else {
3853  replacements.push_back(newResult);
3854  }
3855  }
3856  rewriter.replaceOp(op, replacements);
3857 
3858  return success();
3859  }
3860 };
3861 
3862 //===----------------------------------------------------------------------===//
3863 // TensorDialect
3864 //===----------------------------------------------------------------------===//
3865 
3866 void TensorDialect::getCanonicalizationPatterns(
3867  RewritePatternSet &results) const {
3868  results.add<FoldTensorCastProducerOp>(getContext());
3869 }
3870 
3871 //===----------------------------------------------------------------------===//
3872 // TableGen'd op method definitions
3873 //===----------------------------------------------------------------------===//
3874 
3875 #define GET_OP_CLASSES
3876 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
Operation::operand_range getIndices(Operation *op)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static void getDynamicSizes(RankedTensorType tp, const SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: TensorOps.cpp:3643
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition: TensorOps.cpp:281
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition: TensorOps.cpp:957
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: TensorOps.cpp:3388
static SliceVerificationResult verifyInsertSliceOp(ShapedType srcType, ShapedType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, ShapedType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition: TensorOps.cpp:2144
static bool isSameTypesWithoutEncoding(Type tp1, Type tp2)
Definition: TensorOps.cpp:1348
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType)
Definition: TensorOps.cpp:1755
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: TensorOps.cpp:3208
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: TensorOps.cpp:3239
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2463
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: TensorOps.cpp:3525
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
Definition: TensorOps.cpp:3510
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: TensorOps.cpp:3223
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition: TensorOps.cpp:2183
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition: TensorOps.cpp:2034
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition: TensorOps.cpp:2056
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
Definition: TensorOps.cpp:3269
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1223
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:3682
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: TensorOps.cpp:3284
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: TensorOps.cpp:3188
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: TensorOps.cpp:3252
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2459
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:115
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:3674
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition: TensorOps.cpp:2203
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition: TensorOps.cpp:1361
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:821
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:43
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:166
Block represents an ordered list of Operations.
Definition: Block.h:30
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:137
unsigned getNumArguments()
Definition: Block.h:117
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:198
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:121
UnitAttr getUnitAttr()
Definition: Builders.cpp:111
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:169
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:343
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:125
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:68
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense float vector or tensor object.
An attribute that represents a reference to a dense integer vector or tensor object.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:520
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:405
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:501
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
This class represents a single result from folding an operation.
Definition: OpDefinition.h:235
This class represents an operand of an operation.
Definition: Value.h:255
This is a value defined by a result of an operation.
Definition: Value.h:442
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:454
Pattern to rewrite a subview op with constant arguments.
Definition: Utils.h:41
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:234
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:550
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:362
result_range getResults()
Definition: Operation.h:394
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:520
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:383
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:216
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:227
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:597
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:549
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:482
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:80
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
bool isIndex() const
Definition: Types.cpp:49
U dyn_cast() const
Definition: Types.h:311
bool isa() const
Definition: Types.h:301
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:112
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:278
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:252
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:1994
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:243
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:213
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Definition: TensorOps.cpp:2436
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
Definition: TensorOps.cpp:2082
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:64
void populateSimplifyTensorPack(RewritePatternSet &patterns)
Patterns to simplify tensor.pack.
Definition: TensorOps.cpp:3182
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:146
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:165
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:49
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:99
This header declares functions that assit transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:322
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:348
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1219
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:329
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
Definition: LoopUtils.cpp:1618
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that would prevent erasing.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape the given source type into the given target type, when possible.
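A hedged sketch (builder is a placeholder OpBuilder):
// Infer how a 2x3x4 tensor collapses into 6x4. On success the result is
// {{0, 1}, {2}}: source dims 0 and 1 fold into target dim 0.
auto src = RankedTensorType::get({2, 3, 4}, builder.getF32Type());
auto dst = RankedTensorType::get({6, 4}, builder.getF32Type());
std::optional<SmallVector<ReassociationIndices>> reassoc =
    getReassociationIndicesForReshape(src, dst);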
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:343
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
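A short sketch (mixedOffsets is a placeholder ArrayRef<OpFoldResult>):
// Split mixed offsets into the (dynamic Values, static int64_t) pair stored by
// ops such as extract_slice; dynamic entries are recorded as
// ShapedType::kDynamic in the static vector.
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);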
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:345
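A minimal sketch of this matcher used with matchPattern above (value is a placeholder Value):
// Recover the constant integer behind an SSA value, if there is one.
APInt cst;
if (matchPattern(value, m_ConstantInt(&cst))) {
  // value is (or splats to) the constant integer cst.
}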
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:57
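A one-line sketch (builder, loc, and ofr are placeholders):
// Lower an OpFoldResult to an SSA index value; an arith.constant is created
// only when ofr currently holds an attribute.
Value idx = getValueOrCreateConstantIndexOp(builder, loc, ofr);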
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices that specifies which of the entries of originalShape are dropped to obtain reducedShape.
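A worked example; the concrete shapes are an illustrative assumption:
// Which unit dims does tensor<1x4x1x8xf32> drop to become tensor<4x8xf32>?
SmallVector<int64_t> originalShape = {1, 4, 1, 8};
SmallVector<int64_t> reducedShape = {4, 8};
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask(originalShape, reducedShape);
// If the reduction is valid, *mask contains {0, 2}: the dropped unit dims.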
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:45
LogicalResult foldDynamicIndexList(Builder &b, SmallVectorImpl< OpFoldResult > &ofrs)
Returns success when any of the elements in ofrs was produced by arith::ConstantIndexOp.
Definition: Utils.cpp:29
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:248
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
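A hedged sketch of the indexing convention, assuming result[i] = input[permutation[i]]:
// {10, 20, 30} permuted by {2, 0, 1} becomes {30, 10, 20}.
SmallVector<int64_t> vals = {10, 20, 30};
applyPermutationToVector(vals, /*permutation=*/{2, 0, 1});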
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType by dropping some dimensions with static size 1.
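A short sketch (builder is a placeholder; SliceVerificationResult::Success is the passing case of the enum above):
// tensor<1x4x1x8xf32> rank-reduces validly to tensor<4x8xf32>.
auto original = RankedTensorType::get({1, 4, 1, 8}, builder.getF32Type());
auto reduced = RankedTensorType::get({4, 8}, builder.getF32Type());
bool ok =
    isRankReducedType(original, reduced) == SliceVerificationResult::Success;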
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:374
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor.cast has a source that is more static than the consumer.
Definition: TensorOps.cpp:3809
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:3813
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2013
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2014
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2001
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2002
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:372
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.