TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/TypeUtilities.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include <algorithm>
29 
30 using namespace mlir;
31 using namespace mlir::tensor;
32 
33 /// Materialize a single constant operation from a given attribute value with
34 /// the desired resultant type.
35 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
36  Attribute value, Type type,
37  Location loc) {
38  if (arith::ConstantOp::isBuildableWith(value, type))
39  return builder.create<arith::ConstantOp>(loc, value, type);
40  if (complex::ConstantOp::isBuildableWith(value, type))
41  return builder.create<complex::ConstantOp>(loc, type,
42  value.cast<ArrayAttr>());
43  return nullptr;
44 }
45 
46 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
47  Location loc, Value value) {
48  auto tensorType = value.getType().cast<RankedTensorType>();
49  SmallVector<OpFoldResult> result;
50  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
51  if (tensorType.isDynamicDim(i)) {
52  Value size = builder.create<tensor::DimOp>(loc, value, i);
53  result.push_back(size);
54  } else {
55  result.push_back(builder.getIndexAttr(tensorType.getDimSize(i)));
56  }
57  }
58  return result;
59 }
60 
61 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
62  OpResult opResult) {
63  auto tensorType = opResult.getType().dyn_cast<TensorType>();
64  assert(tensorType && "expected tensor type");
65 
66  // If the op has a destination, it implements DestinationStyleOpInterface and
67  // we can query the destination operand from that interface.
68  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
69  if (destOp)
70  return destOp.getTiedOpOperand(opResult)->get();
71 
72  // Otherwise, create a new destination tensor with the same shape.
73  OpBuilder::InsertionGuard g(b);
74  b.setInsertionPoint(opResult.getDefiningOp());
75 
76  // Compute sizes.
77  SmallVector<OpFoldResult> mixedSizes;
78  if (!tensorType.hasStaticShape()) {
79  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
80  ReifiedRankedShapedTypeDims reifiedShapes;
81  ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
82  dyn_cast<ReifyRankedShapedTypeOpInterface>(opResult.getDefiningOp());
83  if (!reifyShapedTypeInterface)
84  return failure();
85  if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
86  return failure();
87  mixedSizes = getAsOpFoldResult(reifiedShapes[opResult.getResultNumber()]);
88  } else {
89  // Static shape: Take static sizes directly.
90  for (int64_t sz : tensorType.getShape())
91  mixedSizes.push_back(b.getIndexAttr(sz));
92  }
93 
94  // Create empty tensor.
95  Value emptyTensor =
96  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
97  return emptyTensor;
98 }
99 
100 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
101  Operation *op,
102  SmallVector<Value> &result) {
103  for (OpResult opResult : op->getResults()) {
104  if (opResult.getType().isa<TensorType>()) {
105  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
106  if (failed(destination))
107  return failure();
108  result.push_back(*destination);
109  }
110  }
111  return success();
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // CastOp
116 //===----------------------------------------------------------------------===//
117 
118 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
119  setNameFn(getResult(), "cast");
120 }
121 
122 /// Returns true if `target` is a ranked tensor type that preserves static
123 /// information available in the `source` ranked tensor type.
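///
/// A brief illustration (added for exposition; not part of the original
/// comment): assuming 2-D ranked tensor types,
///   preservesStaticInformation(tensor<?x?xf32>, tensor<8x16xf32>) is true,
///   preservesStaticInformation(tensor<8x16xf32>, tensor<?x?xf32>) is false,
/// because the second target would drop the static sizes 8 and 16.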
124 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
125  auto sourceType = source.dyn_cast<RankedTensorType>();
126  auto targetType = target.dyn_cast<RankedTensorType>();
127 
128  // Requires RankedTensorType.
129  if (!sourceType || !targetType)
130  return false;
131 
132  // Requires same elemental type.
133  if (sourceType.getElementType() != targetType.getElementType())
134  return false;
135 
136  // Requires same rank.
137  if (sourceType.getRank() != targetType.getRank())
138  return false;
139 
140  // If cast is towards more static sizes along any dimension, don't fold.
141  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
142  if (!ShapedType::isDynamic(std::get<0>(t)) &&
143  ShapedType::isDynamic(std::get<1>(t)))
144  return false;
145  }
146 
147  return true;
148 }
149 
150 /// Determines whether tensor::CastOp casts to a more dynamic version of the
151 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
152 /// implement canonicalization patterns for ops in different dialects that may
153 /// consume the results of tensor.cast operations. Such foldable tensor.cast
154 /// operations are typically inserted when `slice` ops are canonicalized, to
155 /// preserve the type compatibility of their uses.
156 ///
157 /// Returns true when all conditions are met:
158 /// 1. source and result are ranked tensors with same element type and rank.
159 /// 2. the source type has more static information than the result type.
160 ///
161 /// Example:
162 /// ```mlir
163 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
164 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
165 /// ```
166 ///
167 /// folds into:
168 ///
169 /// ```mlir
170 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
171 /// ```
172 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
173  if (!castOp)
174  return false;
175 
176  // Can fold if the source of cast has at least as much static information as
177  // its results.
178  return preservesStaticInformation(castOp.getType(),
179  castOp.getSource().getType());
180 }
181 
182 /// Determines whether the tensor::CastOp casts to a more static version of the
183 /// source tensor. This is useful to fold into a producing op and implement
184 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
185 /// being from different dialects. Returns true when all conditions are met:
186 /// 1. source and result are ranked tensors with same element type and rank.
187 /// 2. the result type has more static information than the source.
188 ///
189 /// Example:
190 /// ```mlir
191 /// %1 = producer ... : tensor<?x?xf32>
192 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
193 /// ```
194 ///
195 /// can be canonicalized to:
196 ///
197 /// ```mlir
198 /// %2 = producer ... : tensor<8x16xf32>
199 /// ```
200 /// Not all ops might be canonicalizable this way, but for those that can be,
201 /// this method provides a check that it is worth doing the canonicalization.
202 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
203  if (!castOp)
204  return false;
205  return preservesStaticInformation(castOp.getSource().getType(),
206  castOp.getType());
207 }
208 
209 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
210 /// that can be folded.
211 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
212  bool folded = false;
213  for (OpOperand &operand : op->getOpOperands()) {
214  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
215  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
216  operand.set(castOp.getOperand());
217  folded = true;
218  }
219  }
220  return success(folded);
221 }
222 
223 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
224  if (inputs.size() != 1 || outputs.size() != 1)
225  return false;
226  Type a = inputs.front(), b = outputs.front();
227  auto aT = a.dyn_cast<TensorType>();
228  auto bT = b.dyn_cast<TensorType>();
229  if (!aT || !bT)
230  return false;
231 
232  if (aT.getElementType() != bT.getElementType())
233  return false;
234 
235  return succeeded(verifyCompatibleShape(aT, bT));
236 }
237 
238 /// Compute a TensorType that has the joined shape knowledge of the two
239 /// given TensorTypes. The element types need to match.
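///
/// A small illustration (added for exposition): joining tensor<?x8xf32> with
/// tensor<4x?xf32> yields tensor<4x8xf32>; joining tensor<4x8xf32> with
/// tensor<5x8xf32> returns a null type because the static sizes 4 and 5
/// conflict.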
240 static TensorType joinShapes(TensorType one, TensorType two) {
241  assert(one.getElementType() == two.getElementType());
242 
243  if (!one.hasRank())
244  return two;
245  if (!two.hasRank())
246  return one;
247 
248  int64_t rank = one.getRank();
249  if (rank != two.getRank())
250  return {};
251 
252  SmallVector<int64_t, 4> join;
253  join.reserve(rank);
254  for (int64_t i = 0; i < rank; ++i) {
255  if (one.isDynamicDim(i)) {
256  join.push_back(two.getDimSize(i));
257  continue;
258  }
259  if (two.isDynamicDim(i)) {
260  join.push_back(one.getDimSize(i));
261  continue;
262  }
263  if (one.getDimSize(i) != two.getDimSize(i))
264  return {};
265  join.push_back(one.getDimSize(i));
266  }
267  return RankedTensorType::get(join, one.getElementType());
268 }
269 
270 namespace {
271 
272 /// Replaces chains of two tensor.cast operations by a single tensor.cast
273 /// operation if doing so does not remove runtime constraints.
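///
/// A sketch of the rewrite (illustrative only):
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x8xf32>
/// %2 = tensor.cast %1 : tensor<?x8xf32> to tensor<?x?xf32>
/// // folds to:
/// %2 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
/// ```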
274 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
275  using OpRewritePattern<CastOp>::OpRewritePattern;
276
277  LogicalResult matchAndRewrite(CastOp tensorCast,
278  PatternRewriter &rewriter) const final {
279  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
280 
281  if (!tensorCastOperand)
282  return failure();
283 
284  auto sourceType =
285  tensorCastOperand.getOperand().getType().cast<TensorType>();
286  auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
287  auto resultType = tensorCast.getType().cast<TensorType>();
288 
289  // We can remove the intermediate cast if joining all three produces the
290  // same result as just joining the source and result shapes.
291  auto firstJoin =
292  joinShapes(joinShapes(sourceType, intermediateType), resultType);
293 
294  // The join might not exist if the cast sequence would fail at runtime.
295  if (!firstJoin)
296  return failure();
297 
298  // The newJoin always exists if the above join exists; it might just contain
299  // less information. If so, we cannot drop the intermediate cast, as doing
300  // so would remove runtime checks.
301  auto newJoin = joinShapes(sourceType, resultType);
302  if (firstJoin != newJoin)
303  return failure();
304 
305  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
306  tensorCastOperand.getOperand());
307  return success();
308  }
309 };
310 
311 /// Fold tensor.cast into tensor.extract_slice producer.
312 /// Example:
313 /// ```
314 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
315 /// tensor<128x512xf32> to tensor<?x512xf32>
316 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
317 /// ```
318 /// ->
319 /// ```
320 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
321 /// tensor<128x512xf32> to tensor<16x512xf32>
322 /// ```
323 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
324  using OpRewritePattern<CastOp>::OpRewritePattern;
325
326  LogicalResult matchAndRewrite(CastOp tensorCast,
327  PatternRewriter &rewriter) const final {
328  auto extractOperand =
329  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
330 
331  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
332  tensorCast.getType().getShape() == tensorCast.getSource()
333  .getType()
334  .cast<RankedTensorType>()
335  .getShape())
336  return failure();
337 
338  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
339  auto dimMask = computeRankReductionMask(
340  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
341  size_t dimIndex = 0;
342  for (size_t i = 0, e = sizes.size(); i < e; i++) {
343  if (dimMask && dimMask->count(i))
344  continue;
345  int64_t dim = tensorCast.getType().getShape()[dimIndex++];
346  if (ShapedType::isDynamic(dim))
347  continue;
348  sizes[i] = rewriter.getIndexAttr(dim);
349  }
350 
351  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
352  tensorCast, tensorCast.getType().cast<RankedTensorType>(),
353  extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
354  extractOperand.getMixedStrides());
355  return success();
356  }
357 };
358 
359 } // namespace
360 
361 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
362  MLIRContext *context) {
363  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // DimOp
368 //===----------------------------------------------------------------------===//
369 
370 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
371  setNameFn(getResult(), "dim");
372 }
373 
374 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
375  int64_t index) {
376  auto loc = result.location;
377  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
378  build(builder, result, source, indexValue);
379 }
380 
381 Optional<int64_t> DimOp::getConstantIndex() {
382  return getConstantIntValue(getIndex());
383 }
384 
385 Speculation::Speculatability DimOp::getSpeculatability() {
386  auto constantIndex = getConstantIndex();
387  if (!constantIndex)
388  return Speculation::NotSpeculatable;
389
390  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
391  if (!rankedSourceType)
392  return Speculation::NotSpeculatable;
393
394  // The verifier rejects operations that violate this assertion.
395  assert(constantIndex < rankedSourceType.getRank());
396  return Speculation::Speculatable;
397 }
398 
399 LogicalResult DimOp::verify() {
400  // Assume unknown index to be in range.
401  Optional<int64_t> index = getConstantIndex();
402  if (!index)
403  return success();
404 
405  // Check that constant index is not knowingly out of range.
406  auto type = getSource().getType();
407  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
408  if (*index >= tensorType.getRank())
409  return emitOpError("index is out of range");
410  } else if (type.isa<UnrankedTensorType>()) {
411  // Assume index to be in range.
412  } else {
413  llvm_unreachable("expected operand with tensor type");
414  }
415  return success();
416 }
417 
418 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
419  // All forms of folding require a known index.
420  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
421  if (!index)
422  return {};
423 
424  // Folding for unranked types (UnrankedTensorType) is not supported.
425  auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
426  if (!tensorType)
427  return {};
428 
429  // Fold if the shape extent along the given index is known.
430  if (!tensorType.isDynamicDim(index.getInt())) {
431  Builder builder(getContext());
432  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
433  }
434 
435  Operation *definingOp = getSource().getDefiningOp();
436 
437  // Fold dim to the operand of tensor.generate.
438  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
439  auto resultType =
440  fromElements.getResult().getType().cast<RankedTensorType>();
441  // The case where the type encodes the size of the dimension is handled
442  // above.
443  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
444 
445  // Find the operand of the fromElements that corresponds to this index.
446  auto dynExtents = fromElements.getDynamicExtents().begin();
447  for (auto dim : resultType.getShape().take_front(index.getInt()))
448  if (ShapedType::isDynamic(dim))
449  dynExtents++;
450 
451  return Value{*dynExtents};
452  }
453 
454  // The size at the given index is now known to be a dynamic size.
455  unsigned unsignedIndex = index.getValue().getZExtValue();
456 
457  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
458  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
459  // `resolve-shaped-type-result-dims` pass.
460  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
461  sliceOp.isDynamicSize(unsignedIndex)) {
462  return {sliceOp.getDynamicSize(unsignedIndex)};
463  }
464  }
465 
466  // dim(cast) -> dim
467  if (succeeded(foldTensorCast(*this)))
468  return getResult();
469 
470  return {};
471 }
472 
473 namespace {
474 /// Fold dim of a cast into the dim of the source of the tensor cast.
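///
/// A sketch of the rewrite (illustrative only):
///
/// ```mlir
/// %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
/// %d = tensor.dim %0, %c1 : tensor<?x?xf32>
/// // becomes:
/// %d = tensor.dim %t, %c1 : tensor<4x?xf32>
/// ```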
475 struct DimOfCastOp : public OpRewritePattern<DimOp> {
476  using OpRewritePattern<DimOp>::OpRewritePattern;
477
478  LogicalResult matchAndRewrite(DimOp dimOp,
479  PatternRewriter &rewriter) const override {
480  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
481  if (!castOp)
482  return failure();
483  Value newSource = castOp.getOperand();
484  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
485  return success();
486  }
487 };
488 } // namespace
489 
490 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
491  MLIRContext *context) {
492  results.add<DimOfCastOp>(context);
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // EmptyOp
497 //===----------------------------------------------------------------------===//
498 
499 void EmptyOp::build(OpBuilder &builder, OperationState &result,
500  ArrayRef<int64_t> staticShape, Type elementType,
501  Attribute encoding) {
502  assert(all_of(staticShape,
503  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
504  "expected only static sizes");
505  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
506 }
507 
508 void EmptyOp::build(OpBuilder &builder, OperationState &result,
509  ArrayRef<int64_t> staticShape, Type elementType,
510  ValueRange dynamicSizes, Attribute encoding) {
511  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
512  build(builder, result, tensorType, dynamicSizes);
513 }
514 
515 void EmptyOp::build(OpBuilder &builder, OperationState &result,
516  ArrayRef<OpFoldResult> sizes, Type elementType,
517  Attribute encoding) {
518  SmallVector<int64_t> staticShape;
519  SmallVector<Value> dynamicSizes;
520  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape,
521  ShapedType::kDynamic);
522  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
523 }
524 
525 LogicalResult EmptyOp::verify() {
526  if (getType().getNumDynamicDims() !=
527  static_cast<int64_t>(getDynamicSizes().size()))
528  return emitOpError("incorrect number of dynamic sizes, has ")
529  << getDynamicSizes().size() << ", expected "
530  << getType().getNumDynamicDims();
531  return success();
532 }
533 
534 LogicalResult
535 EmptyOp::reifyResultShapes(OpBuilder &builder,
536  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
537  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
538  unsigned ctr = 0;
539  for (int64_t i = 0; i < getType().getRank(); ++i) {
540  if (getType().isDynamicDim(i)) {
541  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
542  } else {
543  reifiedReturnShapes[0][i] =
544  builder.create<arith::ConstantIndexOp>(getLoc(), getType().getDimSize(i));
545  }
546  }
547  return success();
548 }
549 
550 Value EmptyOp::getDynamicSize(unsigned idx) {
551  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
552  unsigned ctr = 0;
553  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
554  if (getType().isDynamicDim(i))
555  ++ctr;
556  return getDynamicSizes()[ctr];
557 }
558 
559 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
560  SmallVector<OpFoldResult> result;
561  unsigned ctr = 0;
562  OpBuilder b(getContext());
563  for (int64_t i = 0; i < getType().getRank(); ++i) {
564  if (getType().isDynamicDim(i)) {
565  result.push_back(getDynamicSizes()[ctr++]);
566  } else {
567  result.push_back(b.getIndexAttr(getType().getShape()[i]));
568  }
569  }
570  return result;
571 }
572 
573 namespace {
574 /// Change the type of the result of a `tensor.empty` by making the result
575 /// type statically sized along dimensions that in the original operation were
576 /// defined as dynamic, but whose size was defined using a `constant` op. For
577 /// example:
578 ///
579 /// %c5 = arith.constant 5: index
580 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
581 ///
582 /// to
583 ///
584 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
585 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
586  using OpRewritePattern<EmptyOp>::OpRewritePattern;
587
588  LogicalResult matchAndRewrite(EmptyOp op,
589  PatternRewriter &rewriter) const override {
590  SmallVector<int64_t> staticShape(op.getType().getShape().begin(),
591  op.getType().getShape().end());
592  SmallVector<Value> dynamicSizes;
593 
594  // Compute new static and dynamic sizes.
595  unsigned ctr = 0;
596  bool changedType = false;
597  for (int64_t i = 0; i < op.getType().getRank(); ++i) {
598  if (op.getType().isDynamicDim(i)) {
599  Value dynamicSize = op.getDynamicSizes()[ctr++];
600  Optional<int64_t> cst = getConstantIntValue(dynamicSize);
601  if (cst.has_value()) {
602  staticShape[i] = *cst;
603  changedType = true;
604  } else {
605  dynamicSizes.push_back(dynamicSize);
606  }
607  }
608  }
609 
610  // Stop here if no dynamic size was promoted to static.
611  if (!changedType)
612  return failure();
613 
614  auto tensorType = RankedTensorType::get(
615  staticShape, op.getType().getElementType(), op.getType().getEncoding());
616  auto newOp =
617  rewriter.create<EmptyOp>(op.getLoc(), tensorType, dynamicSizes);
618  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
619  return success();
620  }
621 };
622 
623 /// `tensor.empty` does not define any tensor contents, so a slice of a
624 /// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
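///
/// A sketch of the rewrite (illustrative only):
///
/// ```mlir
/// %0 = tensor.empty(%d) : tensor<?x16xf32>
/// %1 = tensor.extract_slice %0[0, 0] [%s, 8] [1, 1]
///     : tensor<?x16xf32> to tensor<?x8xf32>
/// // becomes:
/// %1 = tensor.empty(%s) : tensor<?x8xf32>
/// ```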
625 struct FoldEmptyTensorWithExtractSliceOp
626  : public OpRewritePattern<ExtractSliceOp> {
627  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
628
629  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
630  PatternRewriter &rewriter) const override {
631  if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
632  return failure();
633 
634  // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
635  // preserved as well as its result type.
636  auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
637  sliceOp.getType().getElementType(),
638  sliceOp.getType().getEncoding());
639  rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
640  sliceOp.getSizes());
641  return success();
642  }
643 };
644 
645 template <typename ReshapeOp>
646 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
647  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
648
649  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
650  PatternRewriter &rewriter) const override {
651  if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
652  return failure();
653  Location loc = reshapeOp.getLoc();
654  ReifiedRankedShapedTypeDims resultShapes;
655  ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
656  cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
657  if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
658  resultShapes)) ||
659  !llvm::hasSingleElement(resultShapes))
660  return failure();
661  // TODO: Do not drop tensor type encoding.
662  Value emptyTensor =
663  rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
664  reshapeOp.getResultType().getElementType());
665  if (emptyTensor.getType() != reshapeOp.getResultType()) {
666  rewriter.replaceOpWithNewOp<tensor::CastOp>(
667  reshapeOp, reshapeOp.getResultType(), emptyTensor);
668  } else {
669  rewriter.replaceOp(reshapeOp, emptyTensor);
670  }
671  return success();
672  }
673 };
674 
675 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
676  using OpRewritePattern<DimOp>::OpRewritePattern;
677
678  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
679  PatternRewriter &rewriter) const override {
680  Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
681  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
682  if (!emptyTensorOp || !maybeConstantIndex)
683  return failure();
684  if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
685  return failure();
686  rewriter.replaceOp(dimOp,
687  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
688  return success();
689  }
690 };
691 
692 /// Canonicalize
693 ///
694 /// ```mlir
695 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
696 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
697 /// ```
698 ///
699 /// into
700 ///
701 /// ```mlir
702 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
703 /// ```
704 ///
705 /// This assumes the input program is correct in terms of its shape. So it is
706 /// safe to assume that `%d0` is in fact 4.
707 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
708  using OpRewritePattern<CastOp>::OpRewritePattern;
709
710  LogicalResult matchAndRewrite(CastOp castOp,
711  PatternRewriter &rewriter) const override {
712  if (!canFoldIntoProducerOp(castOp))
713  return failure();
714  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
715  if (!producer)
716  return failure();
717 
718  auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
719  ArrayRef<int64_t> resultShape = resultType.getShape();
720  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
721  SmallVector<OpFoldResult> newMixedSizes;
722  newMixedSizes.reserve(currMixedSizes.size());
723  assert(resultShape.size() == currMixedSizes.size() &&
724  "mismatch in result shape and sizes of empty op");
725  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
726  int64_t newDim = std::get<0>(it);
727  OpFoldResult currDim = std::get<1>(it);
728  // Case 1: The empty tensor dim is static. Check that the tensor cast
729  // result dim matches.
730  if (auto attr = currDim.dyn_cast<Attribute>()) {
731  if (ShapedType::isDynamic(newDim) ||
732  newDim != attr.cast<IntegerAttr>().getInt()) {
733  // Something is off, the cast result shape cannot be more dynamic
734  // than the empty tensor result shape (enforced by
735  // `canFoldIntoProducer`). Abort for now.
736  return rewriter.notifyMatchFailure(
737  producer, "mismatch in static value of shape of empty tensor "
738  "result and cast result");
739  }
740  newMixedSizes.push_back(attr);
741  continue;
742  }
743 
744  // Case 2 : The tensor cast shape is static, but empty tensor result
745  // shape is dynamic.
746  if (!ShapedType::isDynamic(newDim)) {
747  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
748  continue;
749  }
750 
751  // Case 3 : The tensor cast shape is dynamic and empty tensor result
752  // shape is dynamic. Use the dynamic value from the empty tensor op.
753  newMixedSizes.push_back(currDim);
754  }
755 
756  // TODO: Do not drop tensor encoding.
757  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
758  resultType.getElementType());
759  return success();
760  }
761 };
762 
763 } // namespace
764 
765 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
766  MLIRContext *context) {
767  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
768  FoldEmptyTensorWithExtractSliceOp,
769  FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
770  FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>,
771  ReplaceEmptyTensorStaticShapeDims>(context);
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // ExtractOp
776 //===----------------------------------------------------------------------===//
777 
778 namespace {
779 
780 /// Canonicalizes the pattern of the form
781 ///
782 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
783 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
784 ///
785 /// to
786 ///
787 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
788 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
789  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
790
791  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
792  PatternRewriter &rewriter) const final {
793  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
794  if (!tensorCast)
795  return failure();
796  if (!tensorCast.getSource().getType().isa<RankedTensorType>())
797  return failure();
798  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
799  extract, tensorCast.getSource(), extract.getIndices());
800  return success();
801  }
802 };
803 
804 } // namespace
805 
806 void ExtractOp::getAsmResultNames(
807  function_ref<void(Value, StringRef)> setNameFn) {
808  setNameFn(getResult(), "extracted");
809 }
810 
811 LogicalResult ExtractOp::verify() {
812  // Verify the # indices match if we have a ranked type.
813  auto tensorType = getTensor().getType().cast<RankedTensorType>();
814  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
815  return emitOpError("incorrect number of indices for extract_element");
816  return success();
817 }
818 
819 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
820  // If this is a splat elements attribute, simply return the value. All of
821  // the elements of a splat attribute are the same.
822  if (Attribute tensor = operands.front())
823  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
824  return splatTensor.getSplatValue<Attribute>();
825 
826  // Collect the constant indices into the tensor.
827  SmallVector<uint64_t, 8> indices;
828  for (Attribute indice : llvm::drop_begin(operands, 1)) {
829  if (!indice || !indice.isa<IntegerAttr>())
830  return {};
831  indices.push_back(indice.cast<IntegerAttr>().getInt());
832  }
833 
834  // Fold extract(from_elements(...)).
835  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
836  auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
837  auto rank = tensorType.getRank();
838  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
839  "rank mismatch");
840  int flatIndex = 0;
841  int stride = 1;
842  for (int i = rank - 1; i >= 0; --i) {
843  if (i < rank - 1)
844  stride *= tensorType.getDimSize(i);
845  flatIndex += indices[i] * stride;
846  }
847  // Prevent out of bounds accesses. This can happen in invalid code that
848  // will never execute.
849  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
850  flatIndex < 0)
851  return {};
852  return fromElementsOp.getElements()[flatIndex];
853  }
854 
855  // If this is an elements attribute, query the value at the given indices.
856  if (Attribute tensor = operands.front()) {
857  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
858  if (elementsAttr && elementsAttr.isValidIndex(indices))
859  return elementsAttr.getValues<Attribute>()[indices];
860  }
861 
862  return {};
863 }
864 
865 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
866  MLIRContext *context) {
867  results.add<ExtractFromTensorCast>(context);
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // FromElementsOp
872 //===----------------------------------------------------------------------===//
873 
874 void FromElementsOp::getAsmResultNames(
875  function_ref<void(Value, StringRef)> setNameFn) {
876  setNameFn(getResult(), "from_elements");
877 }
878 
879 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
880  Type resultType, ValueRange elements) {
881  result.addOperands(elements);
882  result.addTypes(resultType);
883 }
884 
885 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
886  ValueRange elements) {
887  assert(!elements.empty() && "expected at least one element");
888  Type resultType = RankedTensorType::get(
889  {static_cast<int64_t>(elements.size())}, elements.front().getType());
890  build(builder, result, resultType, elements);
891 }
892 
893 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
894  if (!llvm::is_contained(operands, nullptr))
895  return DenseElementsAttr::get(getType(), operands);
896  return {};
897 }
898 
899 namespace {
900 
901 // Pushes the index_casts that occur before extractions to after the extract.
902 // This minimizes type conversion in some cases and enables the extract
903 // canonicalizer. This changes:
904 //
905 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
906 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
907 //
908 // to the following:
909 //
910 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
911 // %cast = arith.index_cast %extract : i32 to index
912 //
914 //
915 // Consider expanding this to a template and handle all tensor cast
916 // operations.
917 struct ExtractElementFromIndexCast
918  : public OpRewritePattern<tensor::ExtractOp> {
919  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
920
921  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
922  PatternRewriter &rewriter) const final {
923  Location loc = extract.getLoc();
924  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
925  if (!indexCast)
926  return failure();
927 
928  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
929 
930  auto newExtract = rewriter.create<tensor::ExtractOp>(
931  loc, elementTy, indexCast.getIn(), extract.getIndices());
932 
933  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
934  newExtract);
935 
936  return success();
937  }
938 };
939 
940 } // namespace
941 
942 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
943  MLIRContext *context) {
944  results.add<ExtractElementFromIndexCast>(context);
945 }
946 
947 //===----------------------------------------------------------------------===//
948 // GatherOp
949 //===----------------------------------------------------------------------===//
950 
951 void GatherOp::getAsmResultNames(
952  function_ref<void(Value, StringRef)> setNameFn) {
953  setNameFn(getResult(), "gather");
954 }
955 
956 /// Return the inferred result type for a gatherOp where:
957 /// - sourceType is the type of the source tensor gathered from
958 /// - indicesType is the type of the indices used to gather
959 /// - gatherDims are the dims along which the gather occurs.
960 /// Return a full-rank or rank-reduced variant of the type depending on
961 /// the value of rankReduced.
962 ///
963 /// The leading dimensions of the index tensor give the result tensor its
964 /// leading dimensions.
965 /// The trailing dimensions of the result tensor are obtained from the source
966 /// tensor by setting the dimensions specified in gather_dims to `1` (if
967 /// rankReduced is false), or skipping them (otherwise).
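///
/// A worked example (added for exposition; the shapes are illustrative):
/// with source tensor<3x4x5xf32>, indices tensor<6x1xindex> and
/// gather_dims = [1], the leading result dimension is 6 (taken from the
/// indices), and the source dimensions become 3, 1, 5 (dimension 1 is the
/// gathered one), giving tensor<6x3x1x5xf32>, or tensor<6x3x5xf32> when
/// rankReduced is true.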
968 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
969  RankedTensorType indicesType,
970  ArrayRef<int64_t> gatherDims,
971  bool rankReduced) {
972  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
973  resultShape.reserve(resultShape.size() + sourceType.getRank());
974  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
975  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
976  if (!rankReduced)
977  resultShape.push_back(1);
978  continue;
979  }
980  resultShape.push_back(sourceType.getDimSize(idx));
981  }
982  return RankedTensorType::Builder(sourceType).setShape(resultShape);
983 }
984 
985 static LogicalResult
986 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
987  StringRef gatherOrScatter, StringRef sourceOrDest) {
988  if (dims.empty())
989  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
990 
991  int64_t numGatherDims = dims.size();
992  if (numGatherDims > rank)
993  return op->emitOpError(gatherOrScatter)
994  << "_dims overflow " << sourceOrDest << " rank";
995  for (int64_t val : dims) {
996  if (val < 0)
997  return op->emitOpError(gatherOrScatter)
998  << "_dims value must be non-negative";
999  if (val >= rank)
1000  return op->emitOpError(gatherOrScatter)
1001  << "_dims value must be smaller than " << sourceOrDest << " rank";
1002  }
1003  for (int64_t i = 1; i < numGatherDims; ++i) {
1004  if (dims[i - 1] >= dims[i])
1005  return op->emitOpError(gatherOrScatter)
1006  << "_dims values must be strictly increasing";
1007  }
1008  return success();
1009 }
1010 
1011 LogicalResult GatherOp::verify() {
1012  int64_t sourceRank = getSourceType().getRank();
1013  ArrayRef<int64_t> gatherDims = getGatherDims();
1014  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
1015  "gather", "source")))
1016  return failure();
1017 
1018  RankedTensorType expectedResultType = GatherOp::inferResultType(
1019  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1020  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1021  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1022  if (getResultType() != expectedResultType &&
1023  getResultType() != expectedRankReducedResultType) {
1024  return emitOpError("result type mismatch: expected ")
1027  << expectedResultType << " or its rank-reduced variant "
1028  << expectedRankReducedResultType << " (got: " << getResultType()
1029  << ")";
1030  }
1031 
1032  return success();
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // InsertOp
1037 //===----------------------------------------------------------------------===//
1038 
1039 void InsertOp::getAsmResultNames(
1040  function_ref<void(Value, StringRef)> setNameFn) {
1041  setNameFn(getResult(), "inserted");
1042 }
1043 
1044 LogicalResult InsertOp::verify() {
1045  // Verify the # indices match if we have a ranked type.
1046  auto destType = getDest().getType().cast<RankedTensorType>();
1047  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1048  return emitOpError("incorrect number of indices");
1049  return success();
1050 }
1051 
1052 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
1053  Attribute scalar = operands[0];
1054  Attribute dest = operands[1];
1055  if (scalar && dest)
1056  if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
1057  if (scalar == splatDest.getSplatValue<Attribute>())
1058  return dest;
1059  return {};
1060 }
1061 
1062 //===----------------------------------------------------------------------===//
1063 // GenerateOp
1064 //===----------------------------------------------------------------------===//
1065 
1066 void GenerateOp::getAsmResultNames(
1067  function_ref<void(Value, StringRef)> setNameFn) {
1068  setNameFn(getResult(), "generated");
1069 }
1070 
1071 LogicalResult GenerateOp::reifyResultShapes(
1072  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1073  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1074  int idx = 0;
1075  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1076  if (getType().isDynamicDim(dim)) {
1077  reifiedReturnShapes[0][dim] = getOperand(idx++);
1078  } else {
1079  reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
1080  getLoc(), getType().getDimSize(dim));
1081  }
1082  }
1083  return success();
1084 }
1085 
1086 LogicalResult GenerateOp::verify() {
1087  // Ensure that the tensor type has as many dynamic dimensions as are
1088  // specified by the operands.
1089  RankedTensorType resultTy = getType().cast<RankedTensorType>();
1090  if (getNumOperands() != resultTy.getNumDynamicDims())
1091  return emitError("must have as many index operands as dynamic extents "
1092  "in the result type");
1093 
1094  return success();
1095 }
1096 
1097 LogicalResult GenerateOp::verifyRegions() {
1098  RankedTensorType resultTy = getType().cast<RankedTensorType>();
1099  // Ensure that region arguments span the index space.
1100  if (!llvm::all_of(getBody().getArgumentTypes(),
1101  [](Type ty) { return ty.isIndex(); }))
1102  return emitError("all body arguments must be index");
1103  if (getBody().getNumArguments() != resultTy.getRank())
1104  return emitError("must have one body argument per input dimension");
1105 
1106  // Ensure that the region yields an element of the right type.
1107  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1108 
1109  if (yieldOp.getValue().getType() != resultTy.getElementType())
1110  return emitOpError(
1111  "body must be terminated with a `yield` operation of the tensor "
1112  "element type");
1113 
1114  return success();
1115 }
1116 
1117 void GenerateOp::build(
1118  OpBuilder &b, OperationState &result, Type resultTy,
1119  ValueRange dynamicExtents,
1120  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1121  build(b, result, resultTy, dynamicExtents);
1122 
1123  // Build and populate body.
1124  OpBuilder::InsertionGuard guard(b);
1125  Region *bodyRegion = result.regions.front().get();
1126  auto rank = resultTy.cast<RankedTensorType>().getRank();
1127  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1128  SmallVector<Location, 2> argumentLocs(rank, result.location);
1129  Block *bodyBlock =
1130  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1131  bodyBuilder(b, result.location, bodyBlock->getArguments());
1132 }
1133 
1134 namespace {
1135 
1136 /// Canonicalizes tensor.generate operations whose dynamic extent operands are
1137 /// constants into the equivalent operation with those extents reflected in
1138 /// the result type. We also insert a type cast to make sure that the
1139 /// resulting IR is still well-typed.
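///
/// A sketch of the rewrite (illustrative; the generate body is elided):
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = tensor.generate %c8 { ... } : tensor<?xf32>
/// // becomes:
/// %1 = tensor.generate { ... } : tensor<8xf32>
/// %0 = tensor.cast %1 : tensor<8xf32> to tensor<?xf32>
/// ```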
1140 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1141  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1142
1143  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
1144  PatternRewriter &rewriter) const final {
1145  auto resultType =
1146  tensorFromElements.getResult().getType().cast<RankedTensorType>();
1147 
1148  if (resultType.hasStaticShape())
1149  return failure();
1150 
1151  SmallVector<Value, 4> newOperands;
1152  SmallVector<int64_t, 4> newShape;
1153  auto operandsIt = tensorFromElements.getDynamicExtents().begin();
1154 
1155  for (int64_t dim : resultType.getShape()) {
1156  if (!ShapedType::isDynamic(dim)) {
1157  newShape.push_back(dim);
1158  continue;
1159  }
1160  APInt index;
1161  if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
1162  newShape.push_back(ShapedType::kDynamic);
1163  newOperands.push_back(*operandsIt++);
1164  continue;
1165  }
1166  newShape.push_back(index.getSExtValue());
1167  operandsIt++;
1168  }
1169 
1170  if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
1171  return failure();
1172 
1173  auto loc = tensorFromElements.getLoc();
1174  auto newOp = rewriter.create<GenerateOp>(
1175  loc, RankedTensorType::get(newShape, resultType.getElementType()),
1176  newOperands);
1177  rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
1178  newOp.getBody().begin());
1179  rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
1180  newOp);
1181  return success();
1182  }
1183 };
1184 
1185 /// Canonicalizes the pattern of the form
1186 ///
1187 /// %tensor = tensor.generate %x {
1188 /// ^bb0(%arg0: index):
1189 /// <computation>
1190 /// yield %1 : index
1191 /// } : tensor<?xindex>
1192 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1193 ///
1194 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1195 /// tensor.generate operation has no side-effects.
1196 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1197  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1198
1199  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1200  PatternRewriter &rewriter) const final {
1201  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1202  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1203  return failure();
1204 
1205  BlockAndValueMapping mapping;
1206  Block *body = &tensorFromElements.getBody().front();
1207  mapping.map(body->getArguments(), extract.getIndices());
1208  for (auto &op : body->without_terminator())
1209  rewriter.clone(op, mapping);
1210 
1211  auto yield = cast<YieldOp>(body->getTerminator());
1212 
1213  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1214  return success();
1215  }
1216 };
1217 
1218 } // namespace
1219 
1220 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1221  MLIRContext *context) {
1222  // TODO: Move extract pattern to tensor::ExtractOp.
1223  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1224 }
1225 
1226 //===----------------------------------------------------------------------===//
1227 // RankOp
1228 //===----------------------------------------------------------------------===//
1229 
1230 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1231  setNameFn(getResult(), "rank");
1232 }
1233 
1234 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
1235  // Constant fold rank when the rank of the operand is known.
1236  auto type = getOperand().getType();
1237  auto shapedType = type.dyn_cast<ShapedType>();
1238  if (shapedType && shapedType.hasRank())
1239  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1240  return IntegerAttr();
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // ReshapeOp
1245 //===----------------------------------------------------------------------===//
1246 
1247 void ReshapeOp::getAsmResultNames(
1248  function_ref<void(Value, StringRef)> setNameFn) {
1249  setNameFn(getResult(), "reshape");
1250 }
1251 
1252 static int64_t getNumElements(ShapedType type) {
1253  int64_t numElements = 1;
1254  for (auto dim : type.getShape())
1255  numElements *= dim;
1256  return numElements;
1257 }
1258 
1259 LogicalResult ReshapeOp::verify() {
1260  TensorType operandType = getSource().getType().cast<TensorType>();
1261  TensorType resultType = getResult().getType().cast<TensorType>();
1262 
1263  if (operandType.getElementType() != resultType.getElementType())
1264  return emitOpError("element types of source and destination tensor "
1265  "types should be the same");
1266 
1267  int64_t shapeSize =
1268  getShape().getType().cast<RankedTensorType>().getDimSize(0);
1269  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
1270  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
1271 
1272  if (resultRankedType) {
1273  if (operandRankedType && resultRankedType.hasStaticShape() &&
1274  operandRankedType.hasStaticShape()) {
1275  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1276  return emitOpError("source and destination tensor should have the "
1277  "same number of elements");
1278  }
1279  if (ShapedType::isDynamic(shapeSize))
1280  return emitOpError("cannot use shape operand with dynamic length to "
1281  "reshape to statically-ranked tensor type");
1282  if (shapeSize != resultRankedType.getRank())
1283  return emitOpError(
1284  "length of shape operand differs from the result's tensor rank");
1285  }
1286  return success();
1287 }
1288 
1289 //===----------------------------------------------------------------------===//
1290 // Reassociative reshape ops
1291 //===----------------------------------------------------------------------===//
1292 
1293 void CollapseShapeOp::getAsmResultNames(
1294  function_ref<void(Value, StringRef)> setNameFn) {
1295  setNameFn(getResult(), "collapsed");
1296 }
1297 
1298 void ExpandShapeOp::getAsmResultNames(
1299  function_ref<void(Value, StringRef)> setNameFn) {
1300  setNameFn(getResult(), "expanded");
1301 }
1302 
1303 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1304  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1305  "invalid resultDim");
1306  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1307  if (llvm::find(it.value(), resultDim) != it.value().end())
1308  return it.index();
1309  llvm_unreachable("could not find reassociation group");
1310 }
1311 
1312 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1313  return getSymbolLessAffineMaps(getReassociationExprs());
1314 }
1315 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1316  return convertReassociationIndicesToExprs(getContext(),
1317  getReassociationIndices());
1318 }
1319 
1320 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1321  return getSymbolLessAffineMaps(getReassociationExprs());
1322 }
1323 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1324  return convertReassociationIndicesToExprs(getContext(),
1325  getReassociationIndices());
1326 }
1327 
1328 /// Compute the RankedTensorType obtained by applying `reassociation` to
1329 /// `type`.
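///
/// For example (illustrative): applying the reassociation [[0, 1], [2]] to
/// tensor<4x5x6xf32> yields tensor<20x6xf32>; if any dimension within a group
/// is dynamic, the collapsed dimension is dynamic as well, e.g.
/// tensor<4x?x6xf32> yields tensor<?x6xf32>.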
1330 static RankedTensorType
1331 computeTensorReshapeCollapsedType(RankedTensorType type,
1332  ArrayRef<AffineMap> reassociation) {
1333  auto shape = type.getShape();
1334  SmallVector<int64_t, 4> newShape;
1335  newShape.reserve(reassociation.size());
1336 
1337  // Use the fact that reassociation is valid to simplify the logic: only use
1338  // each map's rank.
1339  assert(isReassociationValid(reassociation) && "invalid reassociation");
1340  unsigned currentDim = 0;
1341  for (AffineMap m : reassociation) {
1342  unsigned dim = m.getNumResults();
1343  auto band = shape.slice(currentDim, dim);
1344  int64_t size = 1;
1345  if (llvm::is_contained(band, ShapedType::kDynamic))
1346  size = ShapedType::kDynamic;
1347  else
1348  for (unsigned d = 0; d < dim; ++d)
1349  size *= shape[currentDim + d];
1350  newShape.push_back(size);
1351  currentDim += dim;
1352  }
1353 
1354  return RankedTensorType::get(newShape, type.getElementType());
1355 }
1356 
1357 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1358  ArrayRef<ReassociationIndices> reassociation,
1359  ArrayRef<NamedAttribute> attrs) {
1360  auto resultType = computeTensorReshapeCollapsedType(
1361  src.getType().cast<RankedTensorType>(),
1362  getSymbolLessAffineMaps(
1363  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1364  build(b, result, resultType, src, attrs);
1365  result.addAttribute(getReassociationAttrStrName(),
1366  getReassociationIndicesAttribute(b, reassociation));
1367 }
1368 
1369 // Checks if types are the same, but ignoring encoding on ranked tensors.
1370 static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
1371  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
1372  if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
1373  return rtp1.getShape() == rtp2.getShape() &&
1374  rtp1.getElementType() == rtp2.getElementType();
1375  return false;
1376  }
1377  // Default implementation.
1378  return tp1 == tp2;
1379 }
1380 
1381 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1382  TensorReshapeOp, ExpandShapeOp>::value>
1383 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1384  RankedTensorType expandedType,
1385  RankedTensorType collapsedType) {
1386  if (failed(
1387  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1388  return failure();
1389 
1390  auto maps = op.getReassociationMaps();
1391  RankedTensorType expectedType =
1392  computeTensorReshapeCollapsedType(expandedType, maps);
1393  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
1394  return op.emitOpError("expected collapsed type to be ")
1395  << expectedType << ", but got " << collapsedType;
1396  return success();
1397 }
1398 
1399 LogicalResult ExpandShapeOp::verify() {
1400  auto srcType = getSrcType();
1401  auto resultType = getResultType();
1402  if (srcType.getRank() >= resultType.getRank())
1403  return emitOpError("expected rank expansion, but found source rank ")
1404  << srcType.getRank() << " >= result rank " << resultType.getRank();
1405 
1406  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
1407 }
1408 
1409 LogicalResult CollapseShapeOp::verify() {
1410  auto srcType = getSrcType();
1411  auto resultType = getResultType();
1412  if (srcType.getRank() <= resultType.getRank())
1413  return emitOpError("expected rank reduction, but found source rank ")
1414  << srcType.getRank() << " <= result rank " << resultType.getRank();
1415 
1416  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1417 }
1418 
1419 namespace {
1420 /// Reshape of a splat constant can be replaced with a constant of the result
1421 /// type.
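///
/// A sketch of the rewrite (illustrative only):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<4x2xf32>
/// %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<4x2xf32> into tensor<8xf32>
/// // becomes:
/// %0 = arith.constant dense<1.0> : tensor<8xf32>
/// ```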
1422 template <typename TensorReshapeOp>
1423 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1424  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1425  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1426  PatternRewriter &rewriter) const override {
1427  DenseElementsAttr attr;
1428  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1429  return failure();
1430  if (!attr || !attr.isSplat())
1431  return failure();
1432  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1433  reshapeOp.getResultType(), attr.getRawData());
1434  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1435  return success();
1436  }
1437 };
1438 
1439 /// Reshape of a FromElements can be replaced with a FromElements of the
1440 /// result type
1441 template <typename TensorReshapeOp>
1442 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1443  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1444  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1445  PatternRewriter &rewriter) const override {
1446  auto fromElements =
1447  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1448  if (!fromElements)
1449  return failure();
1450 
1451  auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
1452 
1453  if (!shapedTy.hasStaticShape())
1454  return failure();
1455 
1456  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1457  fromElements.getElements());
1458  return success();
1459  }
1460 };
1461 
1462 // Fold CastOp into CollapseShapeOp when adding static information.
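//
// A sketch of the rewrite (illustrative only):
//
// ```mlir
// %0 = tensor.cast %arg : tensor<2x4xf32> to tensor<?x4xf32>
// %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
// // becomes:
// %c = tensor.collapse_shape %arg [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
// %1 = tensor.cast %c : tensor<8xf32> to tensor<?xf32>
// ```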
1463 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1464  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1465
1466  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1467  PatternRewriter &rewriter) const override {
1468  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1469  if (!tensor::canFoldIntoConsumerOp(castOp))
1470  return failure();
1471 
1472  RankedTensorType srcType =
1473  castOp.getSource().getType().cast<RankedTensorType>();
1474  RankedTensorType newResultType = computeTensorReshapeCollapsedType(
1475  srcType, collapseShapeOp.getReassociationMaps());
1476 
1477  if (newResultType == collapseShapeOp.getResultType()) {
1478  rewriter.updateRootInPlace(collapseShapeOp, [&]() {
1479  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1480  });
1481  } else {
1482  auto newOp = rewriter.create<CollapseShapeOp>(
1483  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1484  collapseShapeOp.getReassociation());
1485  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1486  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1487  }
1488  return success();
1489  }
1490 };
1491 
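/// Fold a tensor.dim of a tensor.expand_shape result into arithmetic on a
/// tensor.dim of its source. A sketch of the rewrite (illustrative only),
/// assuming reassociation [[0, 1]] with static inner size 4:
///
/// ```mlir
/// %d = tensor.dim %expanded, %c0 : tensor<?x4xf32>
/// // becomes (%src is the operand of the tensor.expand_shape):
/// %s = tensor.dim %src, %c0 : tensor<?xf32>
/// %d = affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%s]
/// ```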
1492 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1493  using OpRewritePattern<DimOp>::OpRewritePattern;
1494
1495  LogicalResult matchAndRewrite(DimOp dimOp,
1496  PatternRewriter &rewriter) const override {
1497  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1498  if (!expandShapeOp)
1499  return failure();
1500 
1501  // Only constant dimension values are supported.
1502  Optional<int64_t> dim = dimOp.getConstantIndex();
1503  if (!dim.has_value())
1504  return failure();
1505 
1506  // Skip static dims. These are folded to constant ops.
1507  TensorType resultType = expandShapeOp.getResultType();
1508  if (!resultType.isDynamicDim(*dim))
1509  return failure();
1510 
1511  // Find reassociation group that contains this result dimension.
1512  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1513 
1514  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1515  // ExpandShapeOp would be ambiguous.)
1516  int64_t product = 1;
1517  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1518  for (int64_t d : grp) {
1519  if (d != dim) {
1520  assert(!resultType.isDynamicDim(d) && "expected static dim");
1521  product *= resultType.getDimSize(d);
1522  }
1523  }
1524 
1525  // result dim size = src dim size / (product(other dims in reassoc group))
1526  Value srcDimSz =
1527  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1528  AffineExpr expr;
1529  bindSymbols(dimOp.getContext(), expr);
1530  rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, expr.floorDiv(product),
1531  srcDimSz);
1532  return success();
1533  }
1534 };
1535 
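/// Fold a tensor.dim of a tensor.collapse_shape result into a product of
/// tensor.dim ops on its source. A sketch (illustrative only), assuming
/// reassociation [[0, 1]]:
///
/// ```mlir
/// %d = tensor.dim %collapsed, %c0 : tensor<?xf32>
/// // becomes (%src is the operand of the tensor.collapse_shape):
/// %s0 = tensor.dim %src, %c0 : tensor<?x?xf32>
/// %s1 = tensor.dim %src, %c1 : tensor<?x?xf32>
/// %d = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%s0, %s1]
/// ```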
1536 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1537  using OpRewritePattern<DimOp>::OpRewritePattern;
1538
1539  LogicalResult matchAndRewrite(DimOp dimOp,
1540  PatternRewriter &rewriter) const override {
1541  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1542  if (!collapseShapeOp)
1543  return failure();
1544 
1545  // Only constant dimension values are supported.
1546  Optional<int64_t> dim = dimOp.getConstantIndex();
1547  if (!dim.has_value())
1548  return failure();
1549 
1550  // Skip static dims. These are folded to constant ops.
1551  TensorType resultType = collapseShapeOp.getResultType();
1552  if (!resultType.isDynamicDim(*dim))
1553  return failure();
1554 
1555  // Get reassociation group of the result dimension.
1556  ReassociationIndices group =
1557  collapseShapeOp.getReassociationIndices()[*dim];
1558 
1559  // result dim size = product(dims in reassoc group)
1560  SmallVector<Value> srcDimSizes;
1561  SmallVector<AffineExpr, 4> syms;
1562  AffineExpr product;
1563  for (const auto &it : llvm::enumerate(group)) {
1564  srcDimSizes.push_back(rewriter.create<DimOp>(
1565  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1566  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1567  product = product ? product * syms.back() : syms.back();
1568  }
1569  rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, product, srcDimSizes);
1570  return success();
1571  }
1572 };
1573 } // namespace
1574 
1575 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1576  MLIRContext *context) {
1577  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
1578  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
1579  FoldReshapeWithConstant<ExpandShapeOp>,
1580  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1581  FoldDimOfCollapseShape>(context);
1582 }
1583 
1584 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1585  MLIRContext *context) {
1586  results
1587  .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
1588  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
1589  FoldReshapeWithConstant<CollapseShapeOp>,
1590  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
1591  context);
1592 }
1593 
1594 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1595  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1596 }
1597 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1598  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1599 }
1600 
1601 //===----------------------------------------------------------------------===//
1602 // ExtractSliceOp
1603 //===----------------------------------------------------------------------===//
1604 
1605 void ExtractSliceOp::getAsmResultNames(
1606  function_ref<void(Value, StringRef)> setNameFn) {
1607  setNameFn(getResult(), "extracted_slice");
1608 }
1609 
1610 /// An extract_slice result type can be inferred, when it is not
1611 /// rank-reduced, from the source type and the static representation of
1612 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
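///
/// Illustrative example (added for clarity, not part of the original source):
/// for a source type `tensor<8x16x4xf32>` and static sizes `[4, 4, 4]`, the
/// inferred (non-rank-reduced) result type is `tensor<4x4x4xf32>`; a dynamic
/// size sentinel in any position yields `?` in that position instead.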
1613 RankedTensorType ExtractSliceOp::inferResultType(
1614  ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
1615  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
1616  // An extract_slice op may specify only a leading subset of offset/sizes/
1617  // strides, in which case we complete with offset=0, sizes from the source
1618  // tensor type, and strides=1.
1619  assert(static_cast<int64_t>(staticSizes.size()) ==
1620  sourceShapedTensorType.getRank() &&
1621  "unexpected staticSizes not equal to rank of source");
1622  return RankedTensorType::get(staticSizes,
1623  sourceShapedTensorType.getElementType());
1624 }
1625 
1626 RankedTensorType ExtractSliceOp::inferResultType(
1627  ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
1628  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
1629  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1630  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1631  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1632  ShapedType::kDynamic);
1633  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1634  ShapedType::kDynamic);
1635  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1636  ShapedType::kDynamic);
1637  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
1638  staticSizes, staticStrides);
1639 }
1640 
1641 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
1642 /// number of sizes), drop as many size-1 dimensions as needed to produce an
1643 /// inferred type with the desired rank.
1644 ///
1645 /// Note that there may be multiple ways to compute this rank-reduced type:
1646 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
1647 ///
1648 /// To disambiguate, this function always drops the leading size-1 dimensions.
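///
/// Illustrative example (added for clarity, not part of the original source):
/// with an inferred shape `1x6x1` and a desired result rank of 2, the leading
/// unit dimension is dropped and the canonical rank-reduced type is
/// `tensor<6x1xf32>` rather than `tensor<1x6xf32>`.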
1649 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1650  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1651  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1652  ArrayRef<int64_t> strides) {
1653  // Type inferred in the absence of rank-reducing behavior.
1654  auto inferredType =
1655  inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1656  .cast<RankedTensorType>();
1657  int rankDiff = inferredType.getRank() - desiredResultRank;
1658  if (rankDiff > 0) {
1659  auto shape = inferredType.getShape();
1660  llvm::SmallBitVector dimsToProject =
1661  getPositionsOfShapeOne(rankDiff, shape);
1662  SmallVector<int64_t> projectedShape;
1663  // Best effort rank-reducing: drop 1s in order.
1664  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1665  if (!dimsToProject.test(pos))
1666  projectedShape.push_back(shape[pos]);
1667  inferredType =
1668  RankedTensorType::get(projectedShape, inferredType.getElementType());
1669  }
1670  return inferredType;
1671 }
1672 
1673 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1674  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1675  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1676  ArrayRef<OpFoldResult> strides) {
1677  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1678  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1679  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1680  ShapedType::kDynamic);
1681  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1682  ShapedType::kDynamic);
1683  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1684  ShapedType::kDynamic);
1685  return ExtractSliceOp::inferCanonicalRankReducedResultType(
1686  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1687  staticStrides);
1688 }
1689 
1690 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1691 /// result type. If the type passed is nullptr, it is inferred.
1692 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1693  RankedTensorType resultType, Value source,
1694  ArrayRef<OpFoldResult> offsets,
1695  ArrayRef<OpFoldResult> sizes,
1696  ArrayRef<OpFoldResult> strides,
1697  ArrayRef<NamedAttribute> attrs) {
1698  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1699  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1700  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1701  ShapedType::kDynamic);
1702  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1703  ShapedType::kDynamic);
1704  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1705  ShapedType::kDynamic);
1706  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
1707  // Structuring implementation this way avoids duplication between builders.
1708  if (!resultType) {
1709  resultType =
1710  ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
1711  staticSizes, staticStrides)
1712  .cast<RankedTensorType>();
1713  }
1714  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1715  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1716  b.getDenseI64ArrayAttr(staticSizes),
1717  b.getDenseI64ArrayAttr(staticStrides));
1718  result.addAttributes(attrs);
1719 }
1720 
1721 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1722 /// result type.
1723 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1724  ArrayRef<OpFoldResult> offsets,
1725  ArrayRef<OpFoldResult> sizes,
1726  ArrayRef<OpFoldResult> strides,
1727  ArrayRef<NamedAttribute> attrs) {
1728  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1729 }
1730 
1731 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
1732 /// a Range vector.
1733 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1734  ArrayRef<Range> ranges,
1735  ArrayRef<NamedAttribute> attrs) {
1736  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
1737  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1738 }
1739 
1740 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
1741 /// the type passed is nullptr, it is inferred.
1742 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1743  RankedTensorType resultType, Value source,
1744  ValueRange offsets, ValueRange sizes,
1745  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1746  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1747  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1748  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1749  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1750  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1751  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1752  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1753 }
1754 
1755 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
1756 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1757  ValueRange offsets, ValueRange sizes,
1758  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1759  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1760 }
1761 
1762 template <typename OpTy>
1763 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
1764  OpTy op, Type expectedType) {
1765  auto memrefType = expectedType.cast<ShapedType>();
1766  switch (result) {
1767  case SliceVerificationResult::Success:
1768  return success();
1769  case SliceVerificationResult::RankTooLarge:
1770  return op.emitError("expected rank to be smaller or equal to ")
1771  << "the other rank. ";
1772  case SliceVerificationResult::SizeMismatch:
1773  return op.emitError("expected type to be ")
1774  << expectedType << " or a rank-reduced version. (size mismatch) ";
1775  case SliceVerificationResult::ElemTypeMismatch:
1776  return op.emitError("expected element type to be ")
1777  << memrefType.getElementType();
1778  default:
1779  llvm_unreachable("unexpected extract_slice op verification result");
1780  }
1781 }
1782 
1783 /// Verifier for ExtractSliceOp.
1784 LogicalResult ExtractSliceOp::verify() {
1785  // Verify result type against inferred type.
1786  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
1787  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1788  SliceVerificationResult result = isRankReducedType(expectedType, getType());
1789  return produceSliceErrorMsg(result, *this, expectedType);
1790 }
1791 
1792 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
1793  ArrayRef<int64_t> resultShape = getType().getShape();
1794  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1795  llvm::SmallBitVector droppedDims(mixedSizes.size());
1796  unsigned shapePos = 0;
1797  for (const auto &size : enumerate(mixedSizes)) {
1798  Optional<int64_t> sizeVal = getConstantIntValue(size.value());
1799  // The dimension is preserved if the size is dynamic or not 1, or if the
1800  // size is 1 and the next unmatched result dimension is also statically 1
1801  // (i.e. the unit dimension was kept rather than rank-reduced away).
1802  if (!sizeVal || *sizeVal != 1 ||
1803  (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
1804  shapePos++;
1805  continue;
1806  }
1807  droppedDims.set(size.index());
1808  }
1809  return droppedDims;
1810 }
1811 
1812 LogicalResult ExtractSliceOp::reifyResultShapes(
1813  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1814  reifiedReturnShapes.resize(1);
1815  reifiedReturnShapes[0].reserve(getType().getRank());
1816  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1817  llvm::SmallBitVector droppedDims = getDroppedDims();
1818  Location loc = getLoc();
1819  for (const auto &size : enumerate(mixedSizes)) {
1820  if (droppedDims.test(size.index()))
1821  continue;
1822  if (auto attr = size.value().dyn_cast<Attribute>()) {
1823  reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
1824  loc, attr.cast<IntegerAttr>().getInt()));
1825  continue;
1826  }
1827  reifiedReturnShapes[0].push_back(size.value().get<Value>());
1828  }
1829  return success();
1830 }
1831 
1832 namespace {
1833 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
1834 /// This essentially pushes the tensor.cast past its consuming slice when
1835 /// `canFoldIntoConsumerOp` is true.
1836 ///
1837 /// Example:
1838 /// ```
1839 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1840 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1841 /// tensor<3x4xf32>
1842 /// ```
1843 /// is rewritten into:
1844 /// ```
1845 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
1846 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
1847 /// ```
1848 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1849 public:
1850  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1851 
1852  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
1853  PatternRewriter &rewriter) const override {
1854  // Any constant operand, just return to let the constant folder kick in.
1855  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1856  return matchPattern(operand, matchConstantIndex());
1857  }))
1858  return failure();
1859 
1860  auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
1861  if (!castOp)
1862  return failure();
1863 
1864  if (!canFoldIntoConsumerOp(castOp))
1865  return failure();
1866 
1867  /// Deduce the type of the result to use for the canonicalized operation.
1868  RankedTensorType resultType =
1869  ExtractSliceOp::inferCanonicalRankReducedResultType(
1870  sliceOp.getType().getRank(), sliceOp.getSourceType(),
1871  sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
1872  sliceOp.getMixedStrides());
1873  Value newSlice = rewriter.create<ExtractSliceOp>(
1874  sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
1875  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
1876  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1877  rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
1878  newSlice);
1879  return success();
1880  }
1881 };
1882 
1883 /// Slice elements from `values` into `outValues`. For each dimension, `counts`
1884 /// holds the number of elements to skip in `values` to advance that dimension by one.
1885 /// The output values can be used to construct a DenseElementsAttr.
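///
/// Illustrative example (added for clarity, not part of the original source):
/// for a 4x4 source (counts = [4, 1]) with offsets [1, 0], sizes [2, 2], and
/// strides [2, 1], the elements at positions (1,0), (1,1), (3,0), (3,1) of
/// the original values are copied into `outValues`, in that order.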
1886 template <typename IterTy, typename ElemTy>
1887 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
1888  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1889  ArrayRef<int64_t> strides,
1890  llvm::SmallVectorImpl<ElemTy> *outValues) {
1891  assert(offsets.size() == sizes.size());
1892  assert(offsets.size() == strides.size());
1893  if (offsets.empty())
1894  return;
1895 
1896  int64_t offset = offsets.front();
1897  int64_t size = sizes.front();
1898  int64_t stride = strides.front();
1899  if (offsets.size() == 1) {
1900  for (int64_t i = 0; i < size; ++i, offset += stride)
1901  outValues->push_back(*(values + offset));
1902 
1903  return;
1904  }
1905 
1906  for (int64_t i = 0; i < size; ++i, offset += stride) {
1907  auto begin = values + offset * counts.front();
1908  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
1909  offsets.drop_front(), sizes.drop_front(),
1910  strides.drop_front(), outValues);
1911  }
1912 }
1913 
1914 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
1915 /// folded operation might introduce more constant data; users can control
1916 /// whether the fold is applied through the control function.
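///
/// Illustrative example (added for clarity, not part of the original source):
/// ```mlir
/// %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
/// %slice = tensor.extract_slice %cst[0, 1] [2, 1] [1, 1]
///     : tensor<2x2xi32> to tensor<2x1xi32>
/// ```
/// folds to `arith.constant dense<[[1], [3]]> : tensor<2x1xi32>` when the
/// control function allows it.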
1917 class ConstantOpExtractSliceFolder final
1918  : public OpRewritePattern<ExtractSliceOp> {
1919 public:
1920  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1921 
1922  ConstantOpExtractSliceFolder(MLIRContext *context,
1923  ControlConstantExtractSliceFusionFn controlFn)
1924  : OpRewritePattern<ExtractSliceOp>(context),
1925  controlFn(std::move(controlFn)) {}
1926 
1927  LogicalResult matchAndRewrite(ExtractSliceOp op,
1928  PatternRewriter &rewriter) const override {
1929  DenseElementsAttr attr;
1930  if (!matchPattern(op.getSource(), m_Constant(&attr)))
1931  return failure();
1932 
1933  // A constant splat is handled by fold().
1934  if (attr.isSplat())
1935  return failure();
1936 
1937  // Dynamic result shape is not supported.
1938  auto sourceType = op.getSource().getType().cast<ShapedType>();
1939  auto resultType = op.getResult().getType().cast<ShapedType>();
1940  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
1941  return failure();
1942 
1943  // Customized control over the folding.
1944  if (!controlFn(op))
1945  return failure();
1946 
1947  int64_t count = sourceType.getNumElements();
1948  if (count == 0)
1949  return failure();
1950 
1951  // Check if there are any dynamic parts, which are not supported.
1952  auto offsets = op.getStaticOffsets();
1953  if (llvm::is_contained(offsets, ShapedType::kDynamic))
1954  return failure();
1955  auto sizes = op.getStaticSizes();
1956  if (llvm::is_contained(sizes, ShapedType::kDynamic))
1957  return failure();
1958  auto strides = op.getStaticStrides();
1959  if (llvm::is_contained(strides, ShapedType::kDynamic))
1960  return failure();
1961 
1962  // Compute the stride for each dimension.
1963  SmallVector<int64_t> counts;
1964  ArrayRef<int64_t> shape = sourceType.getShape();
1965  counts.reserve(shape.size());
1966  for (int64_t v : shape) {
1967  count = count / v;
1968  counts.push_back(count);
1969  }
1970 
1971  // New attribute constructed by the sliced values.
1972  DenseElementsAttr newAttr;
1973 
1974  if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
1975  SmallVector<APInt> outValues;
1976  outValues.reserve(sourceType.getNumElements());
1977  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
1978  elems.begin(), counts, offsets, sizes, strides, &outValues);
1979  newAttr = DenseElementsAttr::get(resultType, outValues);
1980  } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
1981  SmallVector<APFloat> outValues;
1982  outValues.reserve(sourceType.getNumElements());
1983  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
1984  elems.begin(), counts, offsets, sizes, strides, &outValues);
1985  newAttr = DenseElementsAttr::get(resultType, outValues);
1986  }
1987 
1988  if (newAttr) {
1989  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
1990  return success();
1991  }
1992 
1993  return failure();
1994  }
1995 
1996 private:
1997  /// Controls whether the fold is applied; users can plug their own
1998  /// heuristics into this function.
1999  ControlConstantExtractSliceFusionFn controlFn;
2000 };
2001 
2002 } // namespace
2003 
2004 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2005  RewritePatternSet &patterns,
2006  const ControlConstantExtractSliceFusionFn &controlFn) {
2007  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2008 }
2009 
2010 /// Return the canonical type of the result of an extract_slice op.
2011 struct SliceReturnTypeCanonicalizer {
2012  RankedTensorType operator()(ExtractSliceOp op,
2013  ArrayRef<OpFoldResult> mixedOffsets,
2014  ArrayRef<OpFoldResult> mixedSizes,
2015  ArrayRef<OpFoldResult> mixedStrides) {
2016  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2017  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2018  mixedStrides);
2019  }
2020 };
2021 
2022 /// A canonicalizer wrapper to replace ExtractSliceOps.
2023 struct SliceCanonicalizer {
2024  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2025  ExtractSliceOp newOp) {
2026  Value replacement = newOp.getResult();
2027  if (replacement.getType() != op.getType())
2028  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2029  replacement);
2030  rewriter.replaceOp(op, replacement);
2031  }
2032 };
2033 
2034 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2035  MLIRContext *context) {
2036  results.add<
2037  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2038  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2039  ExtractSliceOpCastFolder>(context);
2040 }
2041 
2042 /// Return success if `op` is an identity slice: all-zero offsets, all-unit strides, and sizes matching `shapedType`.
2043 static LogicalResult
2044 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2045  ShapedType shapedType) {
2046  OpBuilder b(op.getContext());
2047  for (OpFoldResult ofr : op.getMixedOffsets())
2048  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2049  return failure();
2050  // Rank-reducing noops only need to inspect the leading dimensions:
2051  // llvm::zip is appropriate.
2052  auto shape = shapedType.getShape();
2053  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2054  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2055  return failure();
2056  for (OpFoldResult ofr : op.getMixedStrides())
2057  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2058  return failure();
2059  return success();
2060 }
2061 
2062 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2063 /// slice, we can return the InsertSliceOp's source directly.
2064 // TODO: This only checks the immediate producer; extend to go up the
2065 // insert/extract chain if the slices are disjoint.
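//
// Illustrative example (added for clarity, not part of the original source):
//   %0 = tensor.insert_slice %src into %dest[0, 1] [2, 3] [1, 1]
//       : tensor<2x3xf32> into tensor<4x8xf32>
//   %1 = tensor.extract_slice %0[0, 1] [2, 3] [1, 1]
//       : tensor<4x8xf32> to tensor<2x3xf32>
// folds %1 to %src, because the slices are identical and the types match.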
2066 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2067  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2068 
2069  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2070  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2071  insertOp.isSameAs(extractOp, isSame))
2072  return insertOp.getSource();
2073 
2074  return {};
2075 }
2076 
2077 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
2078  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
2079  auto resultType = getResult().getType().cast<ShapedType>();
2080  if (resultType.hasStaticShape())
2081  return splat.resizeSplat(resultType);
2082  }
2083  if (getSourceType() == getType() &&
2084  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2085  return this->getSource();
2086  if (Value slice = foldExtractAfterInsertSlice(*this))
2087  return slice;
2088 
2089  return OpFoldResult();
2090 }
2091 
2092 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2093  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2094  auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
2095  unsigned rank = rankedTensorType.getRank();
2096  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2097  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2098  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2099  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2100  offsets, sizes, strides);
2101 }
2102 
2103 //===----------------------------------------------------------------------===//
2104 // InsertSliceOp
2105 //===----------------------------------------------------------------------===//
2106 
2107 void InsertSliceOp::getAsmResultNames(
2108  function_ref<void(Value, StringRef)> setNameFn) {
2109  setNameFn(getResult(), "inserted_slice");
2110 }
2111 
2112 // Build an InsertSliceOp with mixed static and dynamic entries.
2113 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2114  Value dest, ArrayRef<OpFoldResult> offsets,
2115  ArrayRef<OpFoldResult> sizes,
2116  ArrayRef<OpFoldResult> strides,
2117  ArrayRef<NamedAttribute> attrs) {
2118  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2119  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2120  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2121  ShapedType::kDynamic);
2122  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2123  ShapedType::kDynamic);
2124  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2125  ShapedType::kDynamic);
2126  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2127  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2128  b.getDenseI64ArrayAttr(staticSizes),
2129  b.getDenseI64ArrayAttr(staticStrides));
2130  result.addAttributes(attrs);
2131 }
2132 
2133 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2134 /// Range vector.
2135 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2136  Value dest, ArrayRef<Range> ranges,
2137  ArrayRef<NamedAttribute> attrs) {
2138  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2139  build(b, result, source, dest, offsets, sizes, strides, attrs);
2140 }
2141 
2142 // Build an InsertSliceOp with dynamic entries.
2143 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2144  Value dest, ValueRange offsets, ValueRange sizes,
2145  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2146  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2147  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2148  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2149  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2150  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2151  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2152  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2153 }
2154 
2155 /// Rank-reducing type verification for both InsertSliceOp and
2156 /// ParallelInsertSliceOp.
2157 static SliceVerificationResult verifyInsertSliceOp(
2158  ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
2159  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
2160  ShapedType *expectedType = nullptr) {
2161  // insert_slice is the inverse of extract_slice, use the same type
2162  // inference.
2163  RankedTensorType expected = ExtractSliceOp::inferResultType(
2164  dstType, staticOffsets, staticSizes, staticStrides);
2165  if (expectedType)
2166  *expectedType = expected;
2167  return isRankReducedType(expected, srcType);
2168 }
2169 
2170 /// Verifier for InsertSliceOp.
2171 LogicalResult InsertSliceOp::verify() {
2172  ShapedType expectedType;
2173  SliceVerificationResult result =
2174  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2175  getStaticSizes(), getStaticStrides(), &expectedType);
2176  return produceSliceErrorMsg(result, *this, expectedType);
2177 }
2178 
2179 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2180 /// can mutate the second InsertSliceOp's destination to the first one's.
2181 ///
2182 /// Example:
2183 ///
2184 /// ```mlir
2185 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2186 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2187 /// ```
2188 ///
2189 /// folds into:
2190 ///
2191 /// ```mlir
2192 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2193 /// ```
2194 ///
2195 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2196 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2197  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2198 
2199  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2200  if (!prevInsertOp ||
2201  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2202  !prevInsertOp.isSameAs(insertOp, isSame))
2203  return failure();
2204 
2205  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2206  return success();
2207 }
2208 
2209 /// Folds round-trip extract/insert slice op pairs.
2210 /// Example:
2211 /// ```mlir
2212 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2213 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2214 /// ```
2215 /// can be folded into %val.
2216 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2217  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2218 
2219  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2220  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2221  !extractOp.isSameAs(insertOp, isSame))
2222  return nullptr;
2223 
2224  return extractOp.getSource();
2225 }
2226 
2227 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
2228  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2229  getSourceType() == getType() &&
2230  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2231  return this->getSource();
2232  if (succeeded(foldInsertAfterInsertSlice(*this)))
2233  return getResult();
2234  if (auto result = foldInsertAfterExtractSlice(*this))
2235  return result;
2236  return OpFoldResult();
2237 }
2238 
2239 LogicalResult InsertSliceOp::reifyResultShapes(
2240  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2241  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
2242  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
2243  reifiedReturnShapes[0][dim] =
2244  builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
2245  }
2246  return success();
2247 }
2248 
2249 namespace {
2250 /// Pattern to rewrite an insert_slice op with constant arguments.
2251 ///
2252 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2253 template <typename InsertOpTy>
2254 class InsertSliceOpConstantArgumentFolder final
2255  : public OpRewritePattern<InsertOpTy> {
2256 public:
2257  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2258 
2259  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2260  PatternRewriter &rewriter) const override {
2261  // No constant operand, just return.
2262  if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
2263  return matchPattern(operand, matchConstantIndex());
2264  }))
2265  return failure();
2266 
2267  // At least one of offsets/sizes/strides is a new constant.
2268  // Form the new list of operands and constant attributes from the
2269  // existing.
2270  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2271  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2272  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2273  canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamic);
2274  canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
2275  canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamic);
2276 
2277  // Create the new op in canonical form.
2278  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2279  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2280  mixedOffsets, mixedSizes, mixedStrides);
2281  Value toInsert = insertSliceOp.getSource();
2282  if (sourceType != insertSliceOp.getSourceType()) {
2283  OpBuilder::InsertionGuard g(rewriter);
2284  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2285  // is that the insertion point is just before the ParallelCombiningOp in
2286  // the parallel case.
2287  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2288  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2289  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2290  sourceType, toInsert);
2291  }
2292  rewriter.replaceOpWithNewOp<InsertOpTy>(
2293  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2294  mixedSizes, mixedStrides);
2295  return success();
2296  }
2297 };
2298 
2299 /// Fold tensor_casts with insert_slice operations. If the source or
2300 /// destination tensor is a tensor_cast that removes static type information,
2301 /// the cast is folded into the insert_slice operation. E.g.:
2302 ///
2303 /// ```mlir
2304 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2305 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2306 /// ```
2307 ///
2308 /// folds into:
2309 ///
2310 /// ```mlir
2311 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2312 /// ```
2313 ///
2314 /// Note: When folding a cast on the destination tensor, the result of the
2315 /// insert_slice operation is casted to ensure that the type of the result did
2316 /// not change.
2317 ///
2318 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2319 template <typename InsertOpTy>
2320 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2321  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2322 
2323  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2324  PatternRewriter &rewriter) const override {
2325  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2326  return matchPattern(operand, matchConstantIndex());
2327  }))
2328  return failure();
2329 
2330  auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
2331  auto castOp = v.getDefiningOp<tensor::CastOp>();
2332  if (!castOp || !canFoldIntoConsumerOp(castOp))
2333  return llvm::None;
2334  return castOp.getSource();
2335  };
2336  Optional<Value> sourceCastSource =
2337  getSourceOfCastOp(insertSliceOp.getSource());
2338  Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
2339  if (!sourceCastSource && !destCastSource)
2340  return failure();
2341 
2342  auto src =
2343  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2344  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2345  auto srcType = src.getType().template cast<ShapedType>();
2346  auto dstType = dst.getType().template cast<ShapedType>();
2347  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2348  insertSliceOp.getStaticSizes(),
2349  insertSliceOp.getStaticStrides()) !=
2350  SliceVerificationResult::Success)
2351  return failure();
2352 
2353  Operation *replacement = rewriter.create<InsertOpTy>(
2354  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2355  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2356 
2357  // In the parallel case there is no result and so nothing to cast.
2358  bool isParallelInsert =
2359  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2360  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2361  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2362  insertSliceOp.getDestType(),
2363  replacement->getResult(0));
2364  }
2365  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2366  return success();
2367  }
2368 };
2369 
2370 /// If additional static type information can be deduced from an insert_slice's
2371 /// size operands, insert an explicit cast of the op's source operand. This
2372 /// enables other canonicalization patterns that are matching for tensor_cast
2373 /// ops such as `ForOpTensorCastFolder` in SCF.
2374 ///
2375 /// Example:
2376 ///
2377 /// ```mlir
2378 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2379 /// : tensor<?x?xf32> into ...
2380 /// ```
2381 ///
2382 /// folds into:
2383 ///
2384 /// ```mlir
2385 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2386 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2387 /// : tensor<64x64xf32> into ...
2388 /// ```
2389 ///
2390 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2391 template <typename InsertOpTy>
2392 struct InsertSliceOpSourceCastInserter final
2393  : public OpRewritePattern<InsertOpTy> {
2394  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2395 
2396  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2397  PatternRewriter &rewriter) const override {
2398  RankedTensorType srcType = insertSliceOp.getSourceType();
2399  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2400  return failure();
2401  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
2402  srcType.getShape().end());
2403  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2404  if (Optional<int64_t> constInt =
2405  getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
2406  newSrcShape[i] = *constInt;
2407  }
2408 
2409  RankedTensorType newSrcType =
2410  RankedTensorType::get(newSrcShape, srcType.getElementType());
2411  if (srcType == newSrcType ||
2412  !preservesStaticInformation(srcType, newSrcType) ||
2413  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2414  return failure();
2415 
2416  // newSrcType is:
2417  // 1) Different from srcType.
2418  // 2) "More static" than srcType.
2419  // 3) Cast-compatible with srcType.
2420  // Insert the cast.
2421  OpBuilder::InsertionGuard g(rewriter);
2422  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2423  // that the insertion point is just before the ParallelCombiningOp in the
2424  // parallel case.
2425  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2426  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2427  Value cast = rewriter.create<tensor::CastOp>(
2428  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2429  rewriter.replaceOpWithNewOp<InsertOpTy>(
2430  insertSliceOp, cast, insertSliceOp.getDest(),
2431  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2432  insertSliceOp.getMixedStrides());
2434  return success();
2435  }
2436 };
2437 } // namespace
2438 
2439 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2440  MLIRContext *context) {
2441  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2442  InsertSliceOpCastFolder<InsertSliceOp>,
2443  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2444 }
2445 
2446 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2447  Location loc,
2448  Value tensor,
2449  Value dest) {
2450  auto rankedTensorType = dest.getType().cast<RankedTensorType>();
2451  unsigned rank = rankedTensorType.getRank();
2452  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2453  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2454  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2455  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2456  sizes, strides);
2457 }
2458 
2459 //===----------------------------------------------------------------------===//
2460 // PadOp
2461 //===----------------------------------------------------------------------===//
2462 
2463 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2464  setNameFn(getResult(), "padded");
2465 }
2466 
2467 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2468 // supports optional types.
2469 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2470  Type typeToInfer, Type typeToInferFrom) {}
2471 
2472 ParseResult parseInferType(OpAsmParser &parser,
2473  Optional<OpAsmParser::UnresolvedOperand> optOperand,
2474  Type &typeToInfer, Type typeToInferFrom) {
2475  if (optOperand)
2476  typeToInfer = typeToInferFrom;
2477  return success();
2478 }
2479 
2480 LogicalResult PadOp::verify() {
2481  auto sourceType = getSource().getType().cast<RankedTensorType>();
2482  auto resultType = getResult().getType().cast<RankedTensorType>();
2483  auto expectedType =
2484  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2485  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2486  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2487  continue;
2488  if (expectedType.isDynamicDim(i))
2489  continue;
2490  return emitError("specified type ")
2491  << resultType << " does not match the inferred type "
2492  << expectedType;
2493  }
2494 
2495  return success();
2496 }
2497 
2498 LogicalResult PadOp::verifyRegions() {
2499  auto &region = getRegion();
2500  unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
2501  Block &block = region.front();
2502  if (block.getNumArguments() != rank)
2503  return emitError("expected the block to have ") << rank << " arguments";
2504 
2505  // Note: the number and type of yield values are checked in the YieldOp.
2506  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2507  if (!en.value().isIndex())
2508  return emitOpError("expected block argument ")
2509  << (en.index() + 1) << " to be an index";
2510  }
2511 
2512  // Ensure that the region yields an element of the right type.
2513  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2514  if (yieldOp.getValue().getType() !=
2515  getType().cast<ShapedType>().getElementType())
2516  return emitOpError("expected yield type to match shape element type");
2517 
2518  return success();
2519 }
2520 
2521 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2522  ArrayRef<int64_t> staticLow,
2523  ArrayRef<int64_t> staticHigh,
2524  ArrayRef<int64_t> resultShape) {
2525  unsigned rank = sourceType.getRank();
2526  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
2527  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
2528  assert((resultShape.empty() || resultShape.size() == rank) &&
2529  "unexpected resultShape size mismatch");
2530 
2531  SmallVector<int64_t, 4> inferredShape;
2532  for (auto i : llvm::seq<unsigned>(0, rank)) {
2533  if (sourceType.isDynamicDim(i) ||
2534  staticLow[i] == ShapedType::kDynamic ||
2535  staticHigh[i] == ShapedType::kDynamic) {
2536  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2537  : resultShape[i]);
2538  } else {
2539  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2540  assert((resultShape.empty() || size == resultShape[i] ||
2541  resultShape[i] == ShapedType::kDynamic) &&
2542  "mismatch between inferred shape and result shape");
2543  inferredShape.push_back(size);
2544  }
2545  }
2546 
2547  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2548 }
2549 
2550 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
2551  ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
2552  ValueRange low, ValueRange high, bool nofold,
2553  ArrayRef<NamedAttribute> attrs) {
2554  auto sourceType = source.getType().cast<RankedTensorType>();
2555  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
2556  build(b, result, resultType, source, low, high,
2557  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2558  nofold ? b.getUnitAttr() : UnitAttr());
2559  result.addAttributes(attrs);
2560 }
2561 
2562 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
2563  ValueRange low, ValueRange high, bool nofold,
2564  ArrayRef<NamedAttribute> attrs) {
2565  auto sourceType = source.getType().cast<RankedTensorType>();
2566  unsigned rank = sourceType.getRank();
2567  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
2568  build(b, result, source, staticVector, staticVector, low, high, nofold,
2569  attrs);
2570 }
2571 
2572 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2573  Value source, ArrayRef<OpFoldResult> low,
2574  ArrayRef<OpFoldResult> high, bool nofold,
2575  ArrayRef<NamedAttribute> attrs) {
2576  auto sourceType = source.getType().cast<RankedTensorType>();
2577  SmallVector<Value, 4> dynamicLow, dynamicHigh;
2578  SmallVector<int64_t, 4> staticLow, staticHigh;
2579  // staticLow and staticHigh have full information of the padding config.
2580  // Each padding entry grows staticLow and staticHigh by one value. If the
2581  // entry is dynamic (i.e. not a constant), dynamicLow and dynamicHigh grow
2582  // by one value as well.
2583  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
2584  ShapedType::kDynamic);
2585  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
2586  ShapedType::kDynamic);
2587  if (!resultType) {
2588  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
2589  }
2590  assert(resultType.isa<RankedTensorType>());
2591  build(b, result, resultType, source, dynamicLow, dynamicHigh,
2592  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2593  nofold ? b.getUnitAttr() : UnitAttr());
2594  result.addAttributes(attrs);
2595 }
2596 
2597 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2598  Value source, ArrayRef<OpFoldResult> low,
2599  ArrayRef<OpFoldResult> high, Value constantPadValue,
2600  bool nofold, ArrayRef<NamedAttribute> attrs) {
2601  build(b, result, resultType, source, low, high, nofold, attrs);
2602 
2603  // Add a region and a block to yield the pad value.
2604  Region *region = result.regions[0].get();
2605  int sourceRank = source.getType().cast<RankedTensorType>().getRank();
2606  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
2607  SmallVector<Location> blockArgLocs(sourceRank, result.location);
2608 
2609  // `builder.createBlock` changes the insertion point within the block. Create
2610  // a guard to reset the insertion point of the builder after it is destroyed.
2611  OpBuilder::InsertionGuard guard(b);
2612  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
2613  b.create<tensor::YieldOp>(result.location, constantPadValue);
2614 }
2615 
2616 llvm::SmallBitVector PadOp::getPaddedDims() {
2617  llvm::SmallBitVector paddedDims(getSourceType().getRank());
2618  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
2619  for (const auto &en : enumerate(paddingWidths))
2620  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
2621  paddedDims.set(en.index());
2622  };
2623  extractPaddedDims(getMixedLowPad());
2624  extractPaddedDims(getMixedHighPad());
2625  return paddedDims;
2626 }
2627 
2628 namespace {
2629 // Folds tensor.pad when padding is static zeros and the attribute
2630 // doesn't request otherwise.
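//
// Illustrative example (added for clarity, not part of the original source):
//   %p = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x4xf32> to tensor<4x4xf32>
// is rewritten to `tensor.cast %t : tensor<4x4xf32> to tensor<4x4xf32>`
// (a no-op cast), unless the `nofold` attribute is set.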
2631 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
2632  using OpRewritePattern<PadOp>::OpRewritePattern;
2633 
2634  LogicalResult matchAndRewrite(PadOp padTensorOp,
2635  PatternRewriter &rewriter) const override {
2636  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
2637  return failure();
2638  if (padTensorOp.getNofold())
2639  return failure();
2640  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2641  padTensorOp, padTensorOp.getResult().getType(),
2642  padTensorOp.getSource());
2643  return success();
2644  }
2645 };
2646 
2647 // Fold CastOp into PadOp when adding static information.
2648 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
2649  using OpRewritePattern<PadOp>::OpRewritePattern;
2650 
2651  LogicalResult matchAndRewrite(PadOp padTensorOp,
2652  PatternRewriter &rewriter) const override {
2653  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
2654  if (!tensor::canFoldIntoConsumerOp(castOp))
2655  return failure();
2656 
2657  auto newResultType = PadOp::inferResultType(
2658  castOp.getSource().getType().cast<RankedTensorType>(),
2659  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2660  padTensorOp.getResultType().getShape());
2661 
2662  if (newResultType == padTensorOp.getResultType()) {
2663  rewriter.updateRootInPlace(padTensorOp, [&]() {
2664  padTensorOp.getSourceMutable().assign(castOp.getSource());
2665  });
2666  } else {
2667  auto newOp = rewriter.create<PadOp>(
2668  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
2669  padTensorOp.getLow(), padTensorOp.getHigh(),
2670  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2671  padTensorOp.getNofold());
2672  BlockAndValueMapping mapper;
2673  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
2674 
2675  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2676  padTensorOp, padTensorOp.getResultType(), newOp);
2677  }
2678  return success();
2679  }
2680 };
2681 
2682 // Fold CastOp using the result of PadOp back into the latter if it adds
2683 // static information.
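//
// Illustrative example (added for clarity, not part of the original source):
//   %p = tensor.pad %t low[0, 0] high[%h, 0] { ... }
//       : tensor<?x4xf32> to tensor<?x4xf32>
//   %c = tensor.cast %p : tensor<?x4xf32> to tensor<8x4xf32>
// becomes a single tensor.pad whose result type is tensor<8x4xf32>.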
2684 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2685  using OpRewritePattern<PadOp>::OpRewritePattern;
2686 
2687  LogicalResult matchAndRewrite(PadOp padTensorOp,
2688  PatternRewriter &rewriter) const override {
2689  if (!padTensorOp.getResult().hasOneUse())
2690  return failure();
2691  auto tensorCastOp =
2692  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2693  if (!tensorCastOp)
2694  return failure();
2695  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
2696  tensorCastOp.getDest().getType()))
2697  return failure();
2698 
2699  auto replacementOp = rewriter.create<PadOp>(
2700  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2701  padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
2702  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2703  padTensorOp.getNofold());
2704  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2705 
2706  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
2707  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
2708  return success();
2709  }
2710 };
2711 
2712 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2713 /// different dimensions. The pattern applies if the following preconditions
2714 /// hold:
2715 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
2716 /// 2) the tensor::ExtractSliceOps have only unit-strides,
2717 /// 3) the tensor::PadOps perform only high-padding,
2718 /// 4) the tensor::PadOps have the same constant padding value,
2719 /// 5) the tensor::PadOps do not have common padding dimensions,
2720 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2721 /// zero-offset for every dimension.
2722 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
2723 ///    padded source dimensions.
2725 ///
2726 /// Example:
2727 ///
2728 /// ```mlir
2729 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2730 /// : tensor<64x64xf32> to tensor<?x64xf32>
2731 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2732 /// } : tensor<?x64xf32> to tensor<8x64xf32>
2733 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2734 /// : tensor<8x64xf32> to tensor<8x?xf32>
2735 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2736 /// } : tensor<8x?xf32> to tensor<8x4xf32>
2737 /// ```
2738 ///
2739 /// folds into:
2740 ///
2741 /// ```mlir
2742 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2743 /// : tensor<64x64xf32> to tensor<?x?xf32>
2744 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2745 /// } : tensor<?x?xf32> to tensor<8x4xf32>
2746 /// ```
2747 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
2748  using OpRewritePattern<PadOp>::OpRewritePattern;
2749 
2750  LogicalResult matchAndRewrite(PadOp padOp,
2751  PatternRewriter &rewriter) const override {
2752  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
2753  if (!innerSliceOp)
2754  return failure();
2755  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
2756  if (!outerPadOp || outerPadOp.getNofold())
2757  return failure();
2758  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
2759  if (!outerSliceOp)
2760  return failure();
2761 
2762  // 1) Fail if the chain is rank-reducing.
2763  int64_t rank = padOp.getSourceType().getRank();
2764  if (outerSliceOp.getSourceType().getRank() != rank) {
2765  return rewriter.notifyMatchFailure(padOp,
2766  "cannot fold rank-reducing chain");
2767  }
2768 
2769  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2770  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
2771  return rewriter.notifyMatchFailure(
2772  padOp, "cannot fold non-unit stride ExtractSliceOps");
2773  }
2774 
2775  // 3) Fail if the tensor::PadOps have non-zero low padding.
2776  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
2777  return rewriter.notifyMatchFailure(padOp,
2778  "cannot fold PadOps with low padding");
2779  }
2780 
2781  // 4) Fail if the tensor::PadOps padding values do not match.
2782  Attribute innerAttr, outerAttr;
2783  Value innerValue = padOp.getConstantPaddingValue();
2784  Value outerValue = outerPadOp.getConstantPaddingValue();
2785  if (!innerValue || !outerValue ||
2786  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
2787  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
2788  innerAttr != outerAttr) {
2789  return rewriter.notifyMatchFailure(
2790  padOp, "cannot fold PadOps with different padding values");
2791  }
2792 
2793  // 5) Fail if a dimension is padded by both tensor::PadOps.
2794  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
2795  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
2796  if (innerDims.anyCommon(outerDims)) {
2797  return rewriter.notifyMatchFailure(
2798  padOp, "cannot fold PadOps with common padding dimensions");
2799  }
2800 
2801  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2802  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2803  // for every dimension, and use the offset of the other pair. Fail if no
2804  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2805  // exists.
2806  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
2807  for (auto &en : enumerate(newOffsets)) {
2808  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
2809  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
2810  if (!innerDims.test(en.index()) &&
2811  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
2812  en.value() = outerOffset;
2813  continue;
2814  }
2815  if (!outerDims.test(en.index()) &&
2816  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
2817  en.value() = innerOffset;
2818  continue;
2819  }
2820  return rewriter.notifyMatchFailure(
2821  padOp, "cannot find zero-offset and zero-padding pair");
2822  }
2823 
2824  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
2825  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
2826  // outer tensor::PadOp and fail if the size of the inner
2827  // tensor::ExtractSliceOp does not match the size of the padded dimension.
2828  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
2829  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
2830  for (auto &en : enumerate(newSizes)) {
2831  if (!outerDims.test(en.index()))
2832  continue;
2833  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
2834  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
2835  assert(!ShapedType::isDynamic(sourceSize) &&
2836  "expected padded dimension to have a static size");
2837  if (getConstantIntValue(sliceSize) != sourceSize) {
2838  return rewriter.notifyMatchFailure(
2839  padOp, "cannot fold since the inner ExtractSliceOp size does not "
2840  "match the size of the outer padding");
2841  }
2842  en.value() = outerSliceOp.getMixedSizes()[en.index()];
2843  }
2844 
2845  // Combine the high paddings of the two tensor::PadOps.
2846  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
2847  for (auto &en : enumerate(newHighPad)) {
2848  if (innerDims.test(en.index()))
2849  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
2850  if (outerDims.test(en.index()))
2851  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
2852  }
2853 
2854  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
2855  // the two paddings in one step.
2856  auto newSliceOp = rewriter.create<ExtractSliceOp>(
2857  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
2858  innerSliceOp.getMixedStrides());
2859  auto newPadOp = rewriter.create<PadOp>(
2860  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
2861  padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
2862  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
2863  newPadOp.getRegion().begin());
2864  rewriter.replaceOp(padOp, newPadOp.getResult());
2865  return success();
2866  }
2867 };
2868 
2869 } // namespace
2870 
2871 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2872  MLIRContext *context) {
2873  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2874  FoldOrthogonalPaddings>(context);
2875 }
2876 
2877 /// Return the padding value of the PadOp if it is constant. In this context,
2878 /// "constant" means an actual constant or "defined outside of the block".
2879 ///
2880 /// Values are considered constant in three cases:
2881 /// - A ConstantLike value.
2882 /// - A basic block argument from a different block.
2883 /// - A value defined outside of the block.
2884 ///
2885 /// If the padding value is not constant, an empty Value is returned.
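///
/// Illustrative example (added for clarity, not part of the original source):
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %t low[0] high[1] {
/// ^bb0(%i: index):
///   tensor.yield %cst : f32
/// } : tensor<?xf32> to tensor<?xf32>
/// ```
/// Here the returned padding value is `%cst`.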
2886 Value PadOp::getConstantPaddingValue() {
2887  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
2888  if (!yieldOp)
2889  return {};
2890  Value padValue = yieldOp.getValue();
2891  // Check if yield value is a constant.
2892  if (matchPattern(padValue, m_Constant()))
2893  return padValue;
2894  // Check if yield value is defined inside the PadOp block.
2895  if (padValue.getParentBlock() == &getRegion().front())
2896  return {};
2897  // Else: Yield value defined outside of the PadOp block.
2898  return padValue;
2899 }
2900 
2901 OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
2902  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
2903  !getNofold())
2904  return getSource();
2905  return {};
2906 }
2907 
2908 //===----------------------------------------------------------------------===//
2909 // ParallelInsertSliceOp
2910 //===----------------------------------------------------------------------===//
2911 
2912 OpResult ParallelInsertSliceOp::getTiedOpResult() {
2913  ParallelCombiningOpInterface parallelCombiningParent =
2914  getParallelCombiningParent();
2915  for (const auto &it :
2916  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
2917  Operation &nextOp = it.value();
2918  if (&nextOp == getOperation())
2919  return parallelCombiningParent.getParentResult(it.index());
2920  }
2921  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
2922 }
2923 
2924 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
2925 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
2926  Value source, Value dest,
2927  ArrayRef<OpFoldResult> offsets,
2928  ArrayRef<OpFoldResult> sizes,
2929  ArrayRef<OpFoldResult> strides,
2930  ArrayRef<NamedAttribute> attrs) {
2931  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2932  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2933  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
2934  ShapedType::kDynamic);
2935  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
2936  ShapedType::kDynamic);
2937  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
2938  ShapedType::kDynamic);
2939  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
2940  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2941  b.getDenseI64ArrayAttr(staticSizes),
2942  b.getDenseI64ArrayAttr(staticStrides));
2943  result.addAttributes(attrs);
2944 }
2945 
2946 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
2947 /// packed into a Range vector.
2948 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
2949  Value source, Value dest,
2950  ArrayRef<Range> ranges,
2951  ArrayRef<NamedAttribute> attrs) {
2952  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2953  build(b, result, source, dest, offsets, sizes, strides, attrs);
2954 }
2955 
2956 // Build a ParallelInsertSliceOp with dynamic entries.
2957 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
2958  Value source, Value dest, ValueRange offsets,
2959  ValueRange sizes, ValueRange strides,
2960  ArrayRef<NamedAttribute> attrs) {
2961  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2962  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2963  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2964  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2965  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2966  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2967  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2968 }
2969 
2970 LogicalResult ParallelInsertSliceOp::verify() {
2971  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
2972  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
2973  << *(getOperation()->getParentOp());
2974 
2975  ShapedType expectedType;
2976  SliceVerificationResult result =
2977  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
2978  getStaticSizes(), getStaticStrides(), &expectedType);
2979  return produceSliceErrorMsg(result, *this, expectedType);
2980 }
2981 
2982 void ParallelInsertSliceOp::getCanonicalizationPatterns(
2983  RewritePatternSet &results, MLIRContext *context) {
2984  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
2985  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
2986  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
2987 }
2988 
2989 //===----------------------------------------------------------------------===//
2990 // ScatterOp
2991 //===----------------------------------------------------------------------===//
2992 
2993 void ScatterOp::getAsmResultNames(
2994  function_ref<void(Value, StringRef)> setNameFn) {
2995  setNameFn(getResult(), "scatter");
2996 }
2997 
2998 LogicalResult ScatterOp::verify() {
2999  int64_t destRank = getDestType().getRank();
3000  ArrayRef<int64_t> scatterDims = getScatterDims();
3001  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3002  "scatter", "dest")))
3003  return failure();
3004 
3005  if (!getUnique())
3006  return emitOpError("requires 'unique' attribute to be set");
3007  // TODO: we could also check statically that there are fewer leading index
3008  // tensor dims than the dest dims. If this is not the case, the unique
3009  // attribute cannot be true.
3010 
3011  // Use the GatherOp::inferResultType on the `dest` type and verify the
3012  // expected type matches the source type.
3013  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3014  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3015  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3016  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3017  if (getSourceType() != expectedSourceType &&
3018  getSourceType() != expectedRankReducedSourceType) {
3019  return emitOpError("source type "
3020  "mismatch: "
3021  "expected ")
3022  << expectedSourceType << " or its rank-reduced variant "
3023  << expectedRankReducedSourceType << " (got: " << getSourceType()
3024  << ")";
3025  }
3026 
3027  return success();
3028 }
3029 
3030 //===----------------------------------------------------------------------===//
3031 // SplatOp
3032 //===----------------------------------------------------------------------===//
3033 
3034 void SplatOp::getAsmResultNames(
3035  function_ref<void(Value, StringRef)> setNameFn) {
3036  setNameFn(getResult(), "splat");
3037 }
3038 
3039 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
3040  auto constOperand = operands.front();
3041  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
3042  return {};
3043 
3044  // SplatElementsAttr::get treats a single value passed as the second arg as
3045  // a splat.
3046  return SplatElementsAttr::get(getType(), {constOperand});
3047 }
3048 
3049 //===----------------------------------------------------------------------===//
3050 // PackOp/UnPackOp Common
3051 //===----------------------------------------------------------------------===//
3052 
3053 template <typename OpTy>
3054 static LogicalResult
3055 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3056  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3057  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3058  "applies to only pack or unpack operations");
3059  int64_t destRank = op.getDestRank();
3060  reifiedReturnShapes.resize(1, SmallVector<Value>(destRank));
3061  for (auto dim : llvm::seq<int64_t>(0, destRank)) {
3062  reifiedReturnShapes[0][dim] =
3063  builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
3064  }
3065  return success();
3066 }
3067 
3068 template <typename OpTy>
3071  "applies to only pack or unpack operations");
3072  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3073  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3074  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3075  assert(tiles.size() == dimsToTile.size() &&
3076  "tiles must match indices of dimension to block");
3077  // Bind each dimension `i` to its tile factor.
3078  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3079  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3080  return dimAndTileMapping;
3081 }
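// Worked example (values chosen for illustration): with
// inner_dims_pos = [0, 2] and inner_tiles = [32, 16], the returned map is
// {0 -> 32, 2 -> 16}, i.e. dimension 0 is tiled by 32 and dimension 2 by 16.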
3082 
3083 template <typename OpTy>
3086  "applies to only pack or unpack operations");
3087  Builder builder(op);
3088  SmallVector<OpFoldResult> mixedInnerTiles;
3089  unsigned dynamicValIndex = 0;
3090  for (int64_t staticTile : op.getStaticInnerTiles()) {
3091  if (!ShapedType::isDynamic(staticTile))
3092  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3093  else
3094  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3095  }
3096  return mixedInnerTiles;
3097 }
3098 
3099 template <typename OpTy>
3102  "applies to only pack or unpack operations");
3103  SmallVector<Value> dynamicTiles;
3104  SmallVector<int64_t> staticTiles;
3105  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles,
3106  ShapedType::kDynamic);
3107  return staticTiles;
3108 }
3109 
3110 /// Returns true if `dimsPos` is invalid. It is invalid when:
3111 /// a) It contains duplicates.
3112 /// b) At least one dimension is out of bounds (a valid `dimPos` satisfies 0 <= dimPos < rank).
3113 /// c) The number of elements in `dimsPos` is greater than `rank`.
3114 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3115  size_t rank) {
3116  size_t dimsPosSize = dimsPos.size();
3117  if (dimsPosSize > rank)
3118  return true;
3119  DenseSet<int64_t> uniqued;
3120  for (int64_t dim : dimsPos)
3121  uniqued.insert(dim);
3122  if (dimsPosSize != uniqued.size())
3123  return true;
3124  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3125  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3126  });
3127 }
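// Worked example (illustrative values): for rank = 3, dimsPos = [0, 2] is
// valid; [0, 0] (duplicate), [0, 3] (out of bounds) and [0, 1, 2, 3] (more
// entries than the rank) are all reported as invalid.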
3128 
3129 /// Returns true if each dimension of `sourceShape` is no larger than the
3130 /// corresponding dimension of `limitShape` (dynamic dimensions are in bounds).
3131 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3132  ArrayRef<int64_t> limitShape) {
3133  assert(
3134  sourceShape.size() == limitShape.size() &&
3135  "expected source shape rank, and limit of the shape to have same rank");
3136  return llvm::all_of(
3137  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3138  int64_t sourceExtent = std::get<0>(it);
3139  int64_t limit = std::get<1>(it);
3140  return ShapedType::isDynamic(sourceExtent) ||
3141  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3142  });
3143 }
3144 
3145 template <typename OpTy>
3148  "applies to only pack or unpack operations");
3149  Operation *op = packOrUnPack.getOperation();
3150 
3151  // Return true if we have a zero-value tile.
3152  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3153  return llvm::any_of(
3154  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3155  };
3156 
3157  // Verify tiles. Do not allow zero tiles.
3158  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3159  if (hasZeros(mixedTiles))
3160  return op->emitError("invalid zero tile factor");
3161 
3162  // Verify inner_dims_pos and outer_dims_perm.
3163  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
3164  ? packOrUnPack.getSourceType()
3165  : packOrUnPack.getDestType();
3166  size_t unpackedRank = unpackedType.getRank();
3167  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3168  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3169  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3170  return op->emitError("invalid inner_dims_pos vector");
3171  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3172  return op->emitError("invalid outer_dims_perm vector");
3173  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3174  return op->emitError("outer_dims_perm must be a permutation or empty");
3175 
3176  // The number of tiling factors must not exceed the input rank for pack (or
3177  // the output rank for unpack), and must match the number of `inner_dims_pos`.
3178  if (mixedTiles.size() > unpackedRank) {
3179  return op->emitError("tiling factors must be less than or equal to the "
3180  "input rank for pack or output rank for unpack");
3181  }
3182  if (mixedTiles.size() != innerDimsPos.size()) {
3183  return op->emitError(
3184  "tiling factors must equal the number of dimensions to tile");
3185  }
3186 
3187  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3188  ? packOrUnPack.getDestType()
3189  : packOrUnPack.getSourceType();
3190  size_t packedRank = packedType.getRank();
3191  // Require output rank to match input rank + number of blocking factors.
3192  if (unpackedRank + mixedTiles.size() != packedRank) {
3193  return op->emitError(
3194  "packed rank must equal unpacked rank + tiling factors");
3195  }
3196 
3197  // Verify that the result shape is at least as large as the minimum shape
3198  // expected by the pack operation, and that the output shape represents
3199  // full tiles.
3200  ShapedType expectedPackedType = PackOp::inferPackedType(
3201  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3202  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3203  return op->emitError("the shape of output is not large enough to hold the "
3204  "packed data. Expected at least ")
3205  << expectedPackedType << ", got " << packedType;
3206  }
3207  if (!llvm::all_of(
3208  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3209  mixedTiles),
3210  [](std::tuple<int64_t, OpFoldResult> it) {
3211  Optional<int64_t> constTileSize =
3212  getConstantIntValue(std::get<1>(it));
3213  int64_t shape = std::get<0>(it);
3214  if (!constTileSize) {
3215  // If specified tile size is dynamic, output shape should
3216  // be dynamic too.
3217  return ShapedType::isDynamic(shape);
3218  } else {
3219  if (ShapedType::isDynamic(shape)) {
3220  // If the tile size is constant but the shape is
3221  // dynamic, accept it. In canonical form a constant
3222  // tile size should lead to a constant shape of the
3223  // tiled dimension, but that is not needed for verification.
3224  return true;
3225  }
3226  return shape == constTileSize.value();
3227  }
3228  })) {
3229  return op->emitError("mismatch in inner tile sizes specified and shaped of "
3230  "tiled dimension in the packed type");
3231  }
3232  return success();
3233 }
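// Hedged shape-level example of the rank and tile checks above (illustrative
// values): packing a tensor<?x16xf32> with inner_dims_pos = [1] and
// inner_tiles = [8] must produce a rank-3 packed tensor (rank 2 + 1 tiling
// factor) whose trailing dimension is either 8 or dynamic; a trailing
// dimension of, say, 4 would trigger the inner-tile mismatch error.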
3234 
3235 //===----------------------------------------------------------------------===//
3236 // PackOp
3237 //===----------------------------------------------------------------------===//
3238 
3239 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3240  setNameFn(getResult(), "pack");
3241 }
3242 
3243 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3244  Value dest, ArrayRef<int64_t> innerDimsPos,
3245  ArrayRef<OpFoldResult> innerTiles,
3246  Optional<Value> paddingValue,
3247  ArrayRef<int64_t> outerDimsPerm) {
3248  assert(innerDimsPos.size() == innerTiles.size() &&
3249  "number of tile sizes specified must match the specified number of "
3250  "original dimensions to be tiled");
3251  SmallVector<int64_t> staticTileSizes;
3252  SmallVector<Value> dynamicTileSizes;
3253  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes,
3254  ShapedType::kDynamic);
3255  build(builder, state, dest.getType(), source, dest,
3256  paddingValue ? paddingValue.value() : nullptr,
3257  outerDimsPerm.empty() ? nullptr
3258  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3259  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3260  builder.getDenseI64ArrayAttr(staticTileSizes));
3261 }
3262 
3263 LogicalResult
3264 PackOp::reifyResultShapes(OpBuilder &builder,
3265  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3266  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3267 }
3268 
3269 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3270  return getDimAndTileMappingImpl(*this);
3271 }
3272 
3273 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3274  return getMixedTilesImpl(*this);
3275 }
3276 
3277 SmallVector<int64_t> PackOp::getStaticTiles() {
3278  return getStaticTilesImpl(*this);
3279 }
3280 
3281 /// Check if we have enough static information to catch undefined behavior when
3282 /// a tile size does not perfectly divide the corresponding input dimension.
3283 static bool
3285  DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
3286  int64_t rank = inputShape.size();
3287  for (int64_t dim = 0; dim < rank; dim++) {
3288  if (ShapedType::isDynamic(inputShape[dim]))
3289  continue;
3290  auto it = dimAndTileMapping.find(dim);
3291  if (it == dimAndTileMapping.end())
3292  continue;
3293  Optional<int64_t> constantTile = getConstantIntValue(it->second);
3294  if (!constantTile)
3295  continue;
3296  if (inputShape[dim] % (*constantTile) != 0)
3297  return true;
3298  }
3299  return false;
3300 }
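// Worked example (illustrative values): for inputShape = [13, 16] and a
// constant tile of 4 on dimension 0, 13 % 4 != 0, so the function returns
// true (a partial tile would be required). With inputShape = [12, 16] it
// returns false; dynamic dimensions and dynamic tiles are skipped.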
3301 
3302 LogicalResult PackOp::verify() {
3303  if (failed(commonVerifierPackAndUnPackOp(*this)))
3304  return failure();
3305 
3306  // Verify padding value, and bail out if the tile does not divide the
3307  // dimension fully. In the case of dynamic tile factors or dimensions, having
3308  // a partial tile is undefined behavior.
3309  auto paddingValue = getPaddingValue();
3310  if (paddingValue &&
3311  paddingValue.getType() != getSourceType().getElementType()) {
3312  return emitOpError("expected padding_value has ")
3313  << getSourceType().getElementType()
3314  << " but got: " << paddingValue.getType();
3315  }
3316 
3317  auto dimAndTileMapping = getDimAndTileMapping();
3318  if (!paddingValue &&
3319  areNotFullTiles(getSourceType().getShape(), dimAndTileMapping)) {
3320  return emitOpError("invalid tile factor provided. Only full tiles are "
3321  "supported when padding_value is not set");
3322  }
3323  return success();
3324 }
3325 
3326 /// Get the expected packed type based on source type, tile factors, position of
3327 /// the inner tiles and permutation of the outer tiled loop.
3328 ShapedType PackOp::inferPackedType(ShapedType sourceType,
3329  ArrayRef<int64_t> innerTileSizes,
3330  ArrayRef<int64_t> innerDimsPos,
3331  ArrayRef<int64_t> outerDimsPerm) {
3332  SmallVector<int64_t> resultShape = llvm::to_vector(sourceType.getShape());
3333  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3334  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3335  continue;
3336  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3337  resultShape[tiledDim.value()] = ShapedType::kDynamic;
3338  continue;
3339  }
3340  resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
3341  innerTileSizes[tiledDim.index()]);
3342  }
3343 
3344  if (!outerDimsPerm.empty())
3345  applyPermutationToVector(resultShape, outerDimsPerm);
3346 
3347  // Append the inner tile dimensions.
3348  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
3349  return RankedTensorType::get(resultShape, sourceType.getElementType());
3350 }
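// Worked example (illustrative values): for sourceType = tensor<128x256xf32>,
// innerDimsPos = [0, 1], innerTileSizes = [32, 16] and outerDimsPerm = [1, 0]:
//   outer dims:  [128/32, 256/16] = [4, 16]  --permute-->  [16, 4]
//   packed type: tensor<16x4x32x16xf32> (inner tile sizes appended).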
3351 
3352 /// Returns true if the tiles and the tiled dims are constant.
3353 template <typename OpTy>
3356  "applies to only pack or unpack operations");
3357  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3358  ? op.getDestType()
3359  : op.getSourceType();
3360  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
3361  for (auto [dimDest, tile] : llvm::zip(
3362  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
3363  Optional<int64_t> constTileSize = getConstantIntValue(tile);
3364  if (!constTileSize || ShapedType::isDynamic(dimDest))
3365  return false;
3366  }
3367  return true;
3368 }
3369 
3370 Speculation::Speculatability PackOp::getSpeculatability() {
3371  if (auto paddingValue = getPaddingValue())
3372  return Speculation::Speculatable;
3373 
3374  // The verifier already rejects operations if we can statically prove that
3375  // the tile sizes do not divide the dimension perfectly; thus, only check
3376  // that the tiles and the tiled inner dimensions are constant.
3377  if (!areTilesAndTiledDimsAllConstant(*this))
3378  return Speculation::NotSpeculatable;
3379 
3380  return Speculation::Speculatable;
3381 }
3382 
3383 //===----------------------------------------------------------------------===//
3384 // UnPackOp
3385 //===----------------------------------------------------------------------===//
3386 
3387 void UnPackOp::getAsmResultNames(
3388  function_ref<void(Value, StringRef)> setNameFn) {
3389  setNameFn(getResult(), "unpack");
3390 }
3391 
3392 LogicalResult
3393 UnPackOp::reifyResultShapes(OpBuilder &builder,
3394  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3395  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3396 }
3397 
3398 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
3399  return getDimAndTileMappingImpl(*this);
3400 }
3401 
3402 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
3403  return getMixedTilesImpl(*this);
3404 }
3405 
3406 SmallVector<int64_t> UnPackOp::getStaticTiles() {
3407  return getStaticTilesImpl(*this);
3408 }
3409 
3410 LogicalResult UnPackOp::verify() {
3411  return commonVerifierPackAndUnPackOp(*this);
3412 }
3413 
3414 Speculation::Speculatability UnPackOp::getSpeculatability() {
3415  // See PackOp::getSpeculatability.
3416  if (!areTilesAndTiledDimsAllConstant(*this))
3417  return Speculation::NotSpeculatable;
3418 
3419  return Speculation::Speculatable;
3420 }
3421 
3422 /// unpack(pack(x)) -> x
3423 LogicalResult UnPackOp::canonicalize(UnPackOp unpackOp,
3424  PatternRewriter &rewriter) {
3425  PackOp packOp = unpackOp.getSource().getDefiningOp<tensor::PackOp>();
3426  if (!packOp || packOp.getDestType() != unpackOp.getSourceType())
3427  return failure();
3428  if (packOp.getInnerDimsPos() != unpackOp.getInnerDimsPos())
3429  return failure();
3430  if (packOp.getOuterDimsPerm() != unpackOp.getOuterDimsPerm())
3431  return failure();
3432  rewriter.replaceOp(unpackOp, packOp.getSource());
3433  return success();
3434 }
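// Hedged example of the canonicalization above (hypothetical IR): when the
// unpack undoes a pack with identical inner_dims_pos, identical (here empty)
// outer_dims_perm, and matching types, the pair folds away.
//
//   %p = tensor.pack %x inner_dims_pos = [0] inner_tiles = [8]
//        into %d0 : tensor<16xf32> -> tensor<2x8xf32>
//   %u = tensor.unpack %p inner_dims_pos = [0] inner_tiles = [8]
//        into %d1 : tensor<2x8xf32> -> tensor<16xf32>
//
//   ==> all uses of %u are replaced with %x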
3435 
3436 //===----------------------------------------------------------------------===//
3437 // TableGen'd op method definitions
3438 //===----------------------------------------------------------------------===//
3439 
3440 #define GET_OP_CLASSES
3441 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"