//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include <algorithm>
#include <optional>

using namespace mlir;
using namespace mlir::tensor;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               llvm::cast<ArrayAttr>(value));
  return nullptr;
}

OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  return builder.getIndexAttr(tensorType.getDimSize(dim));
}

SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
                                                OpResult opResult) {
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, it implements DestinationStyleOpInterface and
  // we can query the destination operand from that interface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  // Otherwise, create a new destination tensor with the same shape.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(opResult.getDefiningOp());

  // Compute sizes.
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: Take static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create empty tensor.
  Value emptyTensor =
      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
}

LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
                                              Operation *op,
                                              SmallVector<Value> &result) {
  for (OpResult opResult : op->getResults()) {
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  }
  return success();
}

bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2; // default implementation
}

/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op
/// or rank-extending tensor.insert_slice op.
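/// Example (illustrative): with `mixedSizes = [1, 4, 1, 8]` and a reduced
/// shape of `[4, 8]`, the unit dims at positions 0 and 2 are the dropped
/// dimensions, so bits 0 and 2 are set in the returned bit vector.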
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = 0;

  for (const auto &size : enumerate(mixedSizes)) {
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        size.value().is<Attribute>() &&
        llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;

    if (shapePos == static_cast<int64_t>(reducedShape.size())) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(size.index());
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      ++shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      ++shapePos;
      continue;
    }

    // Otherwise: Dim is dropped.
    droppedDims.set(size.index());
  }

  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
         "dimension mismatch");
  return droppedDims;
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

namespace {

/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
/// operation.
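/// Example (illustrative):
///
/// ```mlir
/// %a = tensor.bitcast %arg : tensor<4xi32> to tensor<4xui32>
/// %b = tensor.bitcast %a : tensor<4xui32> to tensor<4xsi32>
/// ```
///
/// folds into
///
/// ```mlir
/// %b = tensor.bitcast %arg : tensor<4xi32> to tensor<4xsi32>
/// ```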
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};

} // namespace

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ChainedTensorBitcast>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
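/// Example (illustrative): with source tensor<8x?xf32> and target
/// tensor<8x16xf32>, every static extent of the source stays static, so this
/// returns true; with source tensor<8x16xf32> and target tensor<8x?xf32>,
/// the static extent of dimension 1 is lost, so this returns false.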
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and rank.
/// 2. the source tensor type has more static information than the result type.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of cast has at least as much static information as
  // its results.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}

284 
285 /// Determines whether the tensor::CastOp casts to a more static version of the
286 /// source tensor. This is useful to fold into a producing op and implement
287 /// canonicaliation patterns with the `tensor.cast` op as the root, but producer
288 /// being from different dialects. Returns true when all conditions are met:
289 /// 1. source and result and ranked tensors with same element type and rank.
290 /// 2. the result type has more static information than the source.
291 ///
292 /// Example:
293 /// ```mlir
294 /// %1 = producer ... : tensor<?x?xf32>
295 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
296 /// ```
297 ///
298 /// can be canonicalized to :
299 ///
300 /// ```mlir
301 /// %2 = producer ... : tensor<8x16xf32>
302 /// ```
303 /// Not all ops might be canonicalizable this way, but for those that can be,
304 /// this method provides a check that it is worth doing the canonicalization.
306  if (!castOp)
307  return false;
308  return preservesStaticInformation(castOp.getSource().getType(),
309  castOp.getType());
310 }
311 
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
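/// Example (illustrative): joining tensor<8x?xf32> with tensor<?x16xf32>
/// yields tensor<8x16xf32>; joining tensor<8x?xf32> with tensor<4x?xf32>
/// yields the null type because the static extents conflict.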
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just contain
    // less information. If so, we cannot drop the intermediate cast, as doing
    // so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///    tensor<128x512xf32> to tensor<?x512xf32>
///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
///  %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///    tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  // The verifier rejects operations that violate this assertion.
  assert(constantIndex < rankedSourceType.getRank());
  return Speculation::Speculatable;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
    // `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
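/// Example (illustrative):
///   %c = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
///   %d = tensor.dim %c, %i : tensor<?x?xf32>
/// becomes
///   %d = tensor.dim %t, %i : tensor<4x?xf32>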
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};

/// Fold dim of a destination-passing-style op into the dim of the
/// corresponding init operand.
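/// Example (illustrative):
///   %r = linalg.generic ... outs(%init : tensor<?x?xf32>) ...
///   %d = tensor.dim %r, %c0 : tensor<?x?xf32>
/// becomes
///   %d = tensor.dim %init, %c0 : tensor<?x?xf32>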
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = source.cast<OpResult>().getResultNumber();
    auto initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.updateRootInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
}

//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(all_of(staticShape,
                [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}

LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() !=
      static_cast<int64_t>(getDynamicSizes().size()))
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
  return result;
}

namespace {
/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example
///
///  %c5 = arith.constant 5 : index
///  %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///  %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> staticShape(op.getType().getShape().begin(),
                                     op.getType().getShape().end());
    SmallVector<Value> dynamicSizes;

    // Compute new static and dynamic sizes.
    unsigned ctr = 0;
    bool changedType = false;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (op.getType().isDynamicDim(i)) {
        Value dynamicSize = op.getDynamicSizes()[ctr++];
        std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
        if (cst.has_value()) {
          // Dynamic sizes must be non-negative.
          if (cst.value() < 0)
            return failure();
          staticShape[i] = *cst;
          changedType = true;
        } else {
          dynamicSizes.push_back(dynamicSize);
        }
      }
    }

    // Stop here if no dynamic size was promoted to static.
    if (!changedType)
      return failure();

    auto tensorType = RankedTensorType::get(
        staticShape, op.getType().getElementType(), op.getType().getEncoding());
    auto newOp =
        rewriter.create<EmptyOp>(op.getLoc(), tensorType, dynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};

/// Canonicalize
///
/// ```mlir
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducer`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The tensor cast shape is static, but the empty tensor result
      // shape is dynamic.
      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: The tensor cast shape is dynamic and the empty tensor result
      // shape is dynamic. Use the dynamic value from the empty tensor op.
      newMixedSizes.push_back(currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};

} // namespace

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

namespace {

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

} // namespace

void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // If this is a splat elements attribute, simply return the value. All of
  // the elements of a splat attribute are the same.
  if (Attribute tensor = adaptor.getTensor())
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    // Compute the row-major linearized index into the element list.
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  return {};
}

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}

namespace {

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// Consider expanding this to a template and handle all tensor cast
// operations.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

/// Return the inferred result type for a gatherOp where:
///   - sourceType is the type of the source tensor gathered from
///   - indicesType is the type of the indices used to gather
///   - gatherDims are the dims along which the gather occurs.
/// Return a full rank or rank-reduced variant of the type depending on
/// the value of rankReduced.
///
/// The leading dimensions of the index tensor give the result tensor its
/// leading dimensions.
/// The trailing dimensions of the result tensor are obtained from the source
/// tensor by setting the dimensions specified in gather_dims to `1` (if
/// rankReduced is false), or skipping them (otherwise).
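/// Example (illustrative): with source type tensor<4x5x6xf32>, indices type
/// tensor<1x7x2xindex> and gather_dims = [0, 2], the inferred result type is
/// tensor<1x7x1x5x1xf32>, or tensor<1x7x5xf32> when rankReduced is true.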
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}

static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}

LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type mismatch: expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}

/// Extract operands and shape from a tensor with dynamic extents.
static void operandsAndShape(TensorType resultType,
                             Operation::operand_range dynamicExtents,
                             SmallVectorImpl<Value> &newOperands,
                             SmallVectorImpl<int64_t> &newShape) {
  auto operandsIt = dynamicExtents.begin();
  for (int64_t dim : resultType.getShape()) {
    if (!ShapedType::isDynamic(dim)) {
      newShape.push_back(dim);
      continue;
    }
    APInt index;
    if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
      newShape.push_back(ShapedType::kDynamic);
      newOperands.push_back(*operandsIt++);
      continue;
    }
    newShape.push_back(index.getSExtValue());
    operandsIt++;
  }
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  // Ensure operands are non-negative.
  SmallVector<Value> newOperands;
  SmallVector<int64_t> newShape;
  operandsAndShape(resultType, getDynamicExtents(), newOperands, newShape);
  for (int64_t newdim : newShape) {
    if (newdim < 0 && !ShapedType::isDynamic(newdim))
      return emitError("tensor dimensions must be non-negative");
  }
  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with a constant
/// operand into the equivalent operation with the operand expressed in the
/// result type, instead. We also insert a type cast to make sure that the
/// resulting IR is still well-typed.
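/// Example (illustrative):
///   %c5 = arith.constant 5 : index
///   %0 = tensor.generate %c5 { ... } : tensor<?xindex>
/// becomes
///   %0 = tensor.generate { ... } : tensor<5xindex>
///   %1 = tensor.cast %0 : tensor<5xindex> to tensor<?xindex>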
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        llvm::cast<RankedTensorType>(tensorFromElements.getResult().getType());

    if (resultType.hasStaticShape())
      return failure();

    Operation::operand_range dynamicExtents =
        tensorFromElements.getDynamicExtents();
    SmallVector<Value> newOperands;
    SmallVector<int64_t> newShape;
    operandsAndShape(resultType, dynamicExtents, newOperands, newShape);

    for (int64_t newdim : newShape) {
      // This check also occurs in the verifier, but we need it here too
      // since intermediate passes may have replaced some dynamic dimensions
      // with constants.
      if (newdim < 0 && !ShapedType::isDynamic(newdim))
        return failure();
    }

    if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();
  if (srcType.getRank() >= resultType.getRank())
    return emitOpError("expected rank expansion, but found source rank ")
           << srcType.getRank() << " >= result rank " << resultType.getRank();

  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
}

LogicalResult CollapseShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();
  if (srcType.getRank() <= resultType.getRank())
    return emitOpError("expected rank reduction, but found source rank ")
           << srcType.getRank() << " <= result rank " << resultType.getRank();

  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}

namespace {
/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
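/// Example (illustrative):
///   %cst = arith.constant dense<1.0> : tensor<4x2xf32>
///   %r = tensor.collapse_shape %cst [[0, 1]] : tensor<4x2xf32> into tensor<8xf32>
/// becomes
///   %r = arith.constant dense<1.0> : tensor<8xf32>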
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};

// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
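// Example (illustrative):
//   %s = tensor.splat %x : tensor<2x4xf32>
//   %r = tensor.collapse_shape %s [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
// becomes
//   %r = tensor.splat %x : tensor<8xf32>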
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};

/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type.
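/// Example (illustrative):
///   %t = tensor.from_elements %a, %b, %c, %d : tensor<4xindex>
///   %r = tensor.expand_shape %t [[0, 1]] : tensor<4xindex> into tensor<2x2xindex>
/// becomes
///   %r = tensor.from_elements %a, %b, %c, %d : tensor<2x2xindex>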
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};

// Fold CastOp into CollapseShapeOp when adding static information.
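// Example (illustrative):
//   %c = tensor.cast %t : tensor<2x4xf32> to tensor<?x4xf32>
//   %r = tensor.collapse_shape %c [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
// becomes
//   %0 = tensor.collapse_shape %t [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
//   %r = tensor.cast %0 : tensor<8xf32> to tensor<?xf32>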
1557 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1559 
1560  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1561  PatternRewriter &rewriter) const override {
1562  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1563  if (!tensor::canFoldIntoConsumerOp(castOp))
1564  return failure();
1565 
1566  RankedTensorType srcType =
1567  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1568  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1569  srcType, collapseShapeOp.getReassociationMaps());
1570 
1571  if (newResultType == collapseShapeOp.getResultType()) {
1572  rewriter.updateRootInPlace(collapseShapeOp, [&]() {
1573  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1574  });
1575  } else {
1576  auto newOp = rewriter.create<CollapseShapeOp>(
1577  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1578  collapseShapeOp.getReassociation());
1579  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1580  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1581  }
1582  return success();
1583  }
1584 };
1585 
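/// Fold a tensor.dim of a dynamic result dimension of tensor.expand_shape
/// into an affine.apply on the corresponding source dimension. A sketch with
/// hypothetical shapes:
///
/// ```mlir
/// %e = tensor.expand_shape %t [[0, 1, 2]] : tensor<?xf32> into tensor<?x4x5xf32>
/// %d = tensor.dim %e, %c0 : tensor<?x4x5xf32>
/// ```
///
/// is rewritten so that %d is computed as (dim %t, 0) floordiv 20, i.e. the
/// source size divided by the product of the static sizes in the group.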
1586 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1587   using OpRewritePattern<DimOp>::OpRewritePattern;
1588 
1589  LogicalResult matchAndRewrite(DimOp dimOp,
1590  PatternRewriter &rewriter) const override {
1591  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1592  if (!expandShapeOp)
1593  return failure();
1594 
1595  // Only constant dimension values are supported.
1596  std::optional<int64_t> dim = dimOp.getConstantIndex();
1597  if (!dim.has_value())
1598  return failure();
1599 
1600  // Skip static dims. These are folded to constant ops.
1601  RankedTensorType resultType = expandShapeOp.getResultType();
1602  if (!resultType.isDynamicDim(*dim))
1603  return failure();
1604 
1605  // Find reassociation group that contains this result dimension.
1606  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1607 
1608  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1609  // ExpandShapeOp would be ambiguous.)
1610  int64_t product = 1;
1611  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1612  for (int64_t d : grp) {
1613  if (d != dim) {
1614  assert(!resultType.isDynamicDim(d) && "expected static dim");
1615  product *= resultType.getDimSize(d);
1616  }
1617  }
1618 
1619  // result dim size = src dim size / (product(other dims in reassoc group))
1620  Value srcDimSz =
1621  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1622  AffineExpr expr;
1623  bindSymbols(dimOp.getContext(), expr);
1624  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1625  dimOp, expr.floorDiv(product), srcDimSz);
1626  return success();
1627  }
1628 };
1629 
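/// Fold a tensor.dim of a dynamic result dimension of tensor.collapse_shape
/// into an affine.apply over the source dimensions of the reassociation
/// group. A sketch with hypothetical shapes:
///
/// ```mlir
/// %c = tensor.collapse_shape %t [[0, 1]] : tensor<?x8xf32> into tensor<?xf32>
/// %d = tensor.dim %c, %c0 : tensor<?xf32>
/// ```
///
/// is rewritten so that %d is computed as (dim %t, 0) * (dim %t, 1).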
1630 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1631   using OpRewritePattern<DimOp>::OpRewritePattern;
1632 
1633  LogicalResult matchAndRewrite(DimOp dimOp,
1634  PatternRewriter &rewriter) const override {
1635  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1636  if (!collapseShapeOp)
1637  return failure();
1638 
1639  // Only constant dimension values are supported.
1640  std::optional<int64_t> dim = dimOp.getConstantIndex();
1641  if (!dim.has_value())
1642  return failure();
1643 
1644  // Skip static dims. These are folded to constant ops.
1645  RankedTensorType resultType = collapseShapeOp.getResultType();
1646  if (!resultType.isDynamicDim(*dim))
1647  return failure();
1648 
1649  // Get reassociation group of the result dimension.
1650  ReassociationIndices group =
1651  collapseShapeOp.getReassociationIndices()[*dim];
1652 
1653  // result dim size = product(dims in reassoc group)
1654  SmallVector<Value> srcDimSizes;
1655   SmallVector<AffineExpr> syms;
1656   AffineExpr product;
1657  for (const auto &it : llvm::enumerate(group)) {
1658  srcDimSizes.push_back(rewriter.create<DimOp>(
1659  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1660  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1661  product = product ? product * syms.back() : syms.back();
1662  }
1663  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
1664  srcDimSizes);
1665  return success();
1666  }
1667 };
1668 } // namespace
1669 
1670 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1671  MLIRContext *context) {
1674  FoldReshapeWithConstant<ExpandShapeOp>,
1675  FoldReshapeWithSplat<ExpandShapeOp>,
1676  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1677  FoldDimOfCollapseShape>(context);
1678 }
1679 
1680 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1681  MLIRContext *context) {
1682  results
1685  FoldReshapeWithConstant<CollapseShapeOp>,
1686  FoldReshapeWithSplat<CollapseShapeOp>,
1687  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
1688  context);
1689 }
1690 
1691 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
1692  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
1693  adaptor.getOperands());
1694 }
1695 
1696 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
1697  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
1698  adaptor.getOperands());
1699 }
1700 
1701 //===----------------------------------------------------------------------===//
1702 // ExtractSliceOp
1703 //===----------------------------------------------------------------------===//
1704 
1705 void ExtractSliceOp::getAsmResultNames(
1706  function_ref<void(Value, StringRef)> setNameFn) {
1707  setNameFn(getResult(), "extracted_slice");
1708 }
1709 
1710 /// An extract_slice result type can be inferred, when it is not
1711 /// rank-reduced, from the source type and the static representation of
1712 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
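/// E.g. (illustrative): a source tensor<8x16x4xf32> with static sizes
/// [4, 16, 4] infers the result type tensor<4x16x4xf32>, regardless of the
/// static offsets and strides.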
1713 RankedTensorType ExtractSliceOp::inferResultType(
1714  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
1715  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
1716  // An extract_slice op may specify only a leading subset of offsets/sizes/
1717  // strides, in which case we complete with offset=0, sizes from the source
1718  // tensor type, and strides=1.
1719  assert(static_cast<int64_t>(staticSizes.size()) ==
1720  sourceTensorType.getRank() &&
1721  "unexpected staticSizes not equal to rank of source");
1722  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
1723 }
1724 
1725 RankedTensorType ExtractSliceOp::inferResultType(
1726  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
1727     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
1728  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1729  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1730  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1731  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1732  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1733  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
1734  staticSizes, staticStrides);
1735 }
1736 
1737 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
1738 /// number of sizes), drop as many size 1 as needed to produce an inferred
1739 /// type with the desired rank.
1740 ///
1741 /// Note that there may be multiple ways to compute this rank-reduced type:
1742 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
1743 ///
1744 /// To disambiguate, this function always drops the size-1 dimensions in
1745 /// order, starting from the front.
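///
/// E.g. (illustrative): an inferred shape of 1x6x1 with desiredResultRank = 2
/// is reduced to 6x1 (the leading unit dimension is dropped), not 1x6.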
1745 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1746  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1747  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1748  ArrayRef<int64_t> strides) {
1749  // Type inferred in the absence of rank-reducing behavior.
1750  auto inferredType = llvm::cast<RankedTensorType>(
1751  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
1752  int rankDiff = inferredType.getRank() - desiredResultRank;
1753  if (rankDiff > 0) {
1754  auto shape = inferredType.getShape();
1755  llvm::SmallBitVector dimsToProject =
1756  getPositionsOfShapeOne(rankDiff, shape);
1757  SmallVector<int64_t> projectedShape;
1758  // Best effort rank-reducing: drop 1s in order.
1759  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1760  if (!dimsToProject.test(pos))
1761  projectedShape.push_back(shape[pos]);
1762  inferredType =
1763  RankedTensorType::get(projectedShape, inferredType.getElementType());
1764  }
1765  return inferredType;
1766 }
1767 
1768 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1769  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1770     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1771  ArrayRef<OpFoldResult> strides) {
1772  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1773  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1774  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1775  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1776  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1777  return ExtractSliceOp::inferCanonicalRankReducedResultType(
1778  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1779  staticStrides);
1780 }
1781 
1782 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1783 /// result type. If the type passed is nullptr, it is inferred.
1784 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1785  RankedTensorType resultType, Value source,
1786  ArrayRef<OpFoldResult> offsets,
1787  ArrayRef<OpFoldResult> sizes,
1788  ArrayRef<OpFoldResult> strides,
1789  ArrayRef<NamedAttribute> attrs) {
1790  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1791  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1792  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1793  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1794  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1795  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
1796  // Structuring implementation this way avoids duplication between builders.
1797  if (!resultType) {
1798  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
1799  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
1800  }
1801  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1802  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1803  b.getDenseI64ArrayAttr(staticSizes),
1804  b.getDenseI64ArrayAttr(staticStrides));
1805  result.addAttributes(attrs);
1806 }
1807 
1808 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1809 /// result type.
1810 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1811  ArrayRef<OpFoldResult> offsets,
1812  ArrayRef<OpFoldResult> sizes,
1813  ArrayRef<OpFoldResult> strides,
1814  ArrayRef<NamedAttribute> attrs) {
1815  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1816 }
1817 
1818 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
1819 /// a Range vector.
1820 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1821  ArrayRef<Range> ranges,
1822  ArrayRef<NamedAttribute> attrs) {
1823  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
1824  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1825 }
1826 
1827 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
1828 /// the type passed is nullptr, it is inferred.
1829 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1830  RankedTensorType resultType, Value source,
1831  ValueRange offsets, ValueRange sizes,
1832  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1833  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1834  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1835  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1836  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1837  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1838  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1839  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1840 }
1841 
1842 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
1843 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1844  ValueRange offsets, ValueRange sizes,
1845  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1846  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1847 }
1848 
1849 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
1850  Operation *op,
1851  RankedTensorType expectedType) {
1852  switch (result) {
1853   case SliceVerificationResult::Success:
1854  return success();
1855   case SliceVerificationResult::RankTooLarge:
1856  return op->emitError("expected rank to be smaller or equal to ")
1857  << "the other rank. ";
1858   case SliceVerificationResult::SizeMismatch:
1859  return op->emitError("expected type to be ")
1860  << expectedType << " or a rank-reduced version. (size mismatch) ";
1861   case SliceVerificationResult::ElemTypeMismatch:
1862  return op->emitError("expected element type to be ")
1863  << expectedType.getElementType();
1864  default:
1865  llvm_unreachable("unexpected extract_slice op verification result");
1866  }
1867 }
1868 
1869 /// Verifier for ExtractSliceOp.
1870 LogicalResult ExtractSliceOp::verify() {
1871  // Verify result type against inferred type.
1872  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
1873  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1874  SliceVerificationResult result = isRankReducedType(expectedType, getType());
1875  return produceSliceErrorMsg(result, *this, expectedType);
1876 }
1877 
1878 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
1879   return ::getDroppedDims(getType().getShape(), getMixedSizes());
1880 }
1881 
1882 FailureOr<Value>
1883 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
1884  ArrayRef<int64_t> desiredShape) {
1885  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
1886  assert(sourceTensorType && "not a ranked tensor type");
1887  auto sourceShape = sourceTensorType.getShape();
1888  if (sourceShape.equals(desiredShape))
1889  return value;
1890  auto maybeRankReductionMask =
1891  mlir::computeRankReductionMask(sourceShape, desiredShape);
1892  if (!maybeRankReductionMask)
1893  return failure();
1894   return createCanonicalRankReducingExtractSliceOp(
1895  b, loc, value,
1896  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
1897 }
1898 
1899 LogicalResult ExtractSliceOp::reifyResultShapes(
1900  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1901  reifiedReturnShapes.resize(1);
1902  reifiedReturnShapes[0].reserve(getType().getRank());
1903   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1904  llvm::SmallBitVector droppedDims = getDroppedDims();
1905  for (const auto &size : enumerate(mixedSizes)) {
1906  if (droppedDims.test(size.index()))
1907  continue;
1908  reifiedReturnShapes[0].push_back(size.value());
1909  }
1910  return success();
1911 }
1912 
1913 namespace {
1914 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
1915 /// This essentially pushes the tensor.cast past its consuming slice when
1916 /// `canFoldIntoConsumerOp` is true.
1917 ///
1918 /// Example:
1919 /// ```
1920 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1921 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1922 /// tensor<3x4xf32>
1923 /// ```
1924 /// is rewritten into:
1925 /// ```
1926 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
1927 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
1928 /// ```
1929 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1930 public:
1931   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1932 
1933  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
1934  PatternRewriter &rewriter) const override {
1935  // Any constant operand, just return to let the constant folder kick in.
1936  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1937  return matchPattern(operand, matchConstantIndex());
1938  }))
1939  return failure();
1940 
1941  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
1942  if (!castOp)
1943  return failure();
1944 
1945  if (!canFoldIntoConsumerOp(castOp))
1946  return failure();
1947 
1948  // Create folded extract.
1949  Location loc = sliceOp.getLoc();
1950  Value newResult = rewriter.create<ExtractSliceOp>(
1951  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
1952  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
1953  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1954  if (newResult.getType() != sliceOp.getType())
1955  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
1956  rewriter.replaceOp(sliceOp, newResult);
1957  return success();
1958  }
1959 };
1960 
1961 /// Slice elements from `values` into `outValues`. `counts` holds, for each
1962 /// dimension, the number of elements to stride over in the original values.
1963 /// The output values can be used to construct a DenseElementsAttr.
1964 template <typename IterTy, typename ElemTy>
1965 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
1966  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1967  ArrayRef<int64_t> strides,
1968  llvm::SmallVectorImpl<ElemTy> *outValues) {
1969  assert(offsets.size() == sizes.size());
1970  assert(offsets.size() == strides.size());
1971  if (offsets.empty())
1972  return;
1973 
1974  int64_t offset = offsets.front();
1975  int64_t size = sizes.front();
1976  int64_t stride = strides.front();
1977  if (offsets.size() == 1) {
1978  for (int64_t i = 0; i < size; ++i, offset += stride)
1979  outValues->push_back(*(values + offset));
1980 
1981  return;
1982  }
1983 
1984  for (int64_t i = 0; i < size; ++i, offset += stride) {
1985  auto begin = values + offset * counts.front();
1986  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
1987  offsets.drop_front(), sizes.drop_front(),
1988  strides.drop_front(), outValues);
1989  }
1990 }
1991 
1992 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
1993 /// folded operation might introduce more constant data; users can control
1994 /// the heuristic via the control function.
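///
/// For illustration, a sketch with hypothetical values:
///
/// ```mlir
/// %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
/// %0 = tensor.extract_slice %cst[1] [2] [1] : tensor<4xi32> to tensor<2xi32>
/// ```
///
/// can be folded (subject to the control function) to
///
/// ```mlir
/// %0 = arith.constant dense<[2, 3]> : tensor<2xi32>
/// ```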
1995 class ConstantOpExtractSliceFolder final
1996  : public OpRewritePattern<ExtractSliceOp> {
1997 public:
1998   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1999 
2000  ConstantOpExtractSliceFolder(MLIRContext *context,
2001                                ControlConstantExtractSliceFusionFn controlFn)
2002  : OpRewritePattern<ExtractSliceOp>(context),
2003  controlFn(std::move(controlFn)) {}
2004 
2005  LogicalResult matchAndRewrite(ExtractSliceOp op,
2006  PatternRewriter &rewriter) const override {
2007  DenseElementsAttr attr;
2008  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2009  return failure();
2010 
2011  // A constant splat is handled by fold().
2012  if (attr.isSplat())
2013  return failure();
2014 
2015  // Dynamic result shape is not supported.
2016  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2017  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2018  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2019  return failure();
2020 
2021  // Customized control over the folding.
2022  if (!controlFn(op))
2023  return failure();
2024 
2025  int64_t count = sourceType.getNumElements();
2026  if (count == 0)
2027  return failure();
2028 
2029  // Check if there are any dynamic parts, which are not supported.
2030  auto offsets = op.getStaticOffsets();
2031  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2032  return failure();
2033  auto sizes = op.getStaticSizes();
2034  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2035  return failure();
2036  auto strides = op.getStaticStrides();
2037  if (llvm::is_contained(strides, ShapedType::kDynamic))
2038  return failure();
2039 
2040  // Compute the stride for each dimension.
2041  SmallVector<int64_t> counts;
2042  ArrayRef<int64_t> shape = sourceType.getShape();
2043  counts.reserve(shape.size());
2044  for (int64_t v : shape) {
2045  count = count / v;
2046  counts.push_back(count);
2047  }
2048 
2049  // New attribute constructed by the sliced values.
2050  DenseElementsAttr newAttr;
2051 
2052  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2053  SmallVector<APInt> outValues;
2054  outValues.reserve(sourceType.getNumElements());
2055  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2056  elems.begin(), counts, offsets, sizes, strides, &outValues);
2057  newAttr = DenseElementsAttr::get(resultType, outValues);
2058  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2059  SmallVector<APFloat> outValues;
2060  outValues.reserve(sourceType.getNumElements());
2061  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2062  elems.begin(), counts, offsets, sizes, strides, &outValues);
2063  newAttr = DenseElementsAttr::get(resultType, outValues);
2064  }
2065 
2066  if (newAttr) {
2067  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2068  return success();
2069  }
2070 
2071  return failure();
2072  }
2073 
2074 private:
2075  /// This additionally controls whether the fold happens or not. Users can
2076  /// impose their heuristics in the function.
2077   ControlConstantExtractSliceFusionFn controlFn;
2078 };
2079 
2080 } // namespace
2081 
2082 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2083  RewritePatternSet &patterns,
2084  const ControlConstantExtractSliceFusionFn &controlFn) {
2085  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2086 }
2087 
2088 /// Return the canonical type of the result of an extract_slice op.
2089 struct SliceReturnTypeCanonicalizer {
2090  RankedTensorType operator()(ExtractSliceOp op,
2091  ArrayRef<OpFoldResult> mixedOffsets,
2092  ArrayRef<OpFoldResult> mixedSizes,
2093  ArrayRef<OpFoldResult> mixedStrides) {
2094  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2095  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2096  mixedStrides);
2097  }
2098 };
2099 
2100 /// A canonicalizer wrapper to replace ExtractSliceOps.
2101 struct SliceCanonicalizer {
2102  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2103  ExtractSliceOp newOp) {
2104  Value replacement = newOp.getResult();
2105  if (replacement.getType() != op.getType())
2106  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2107  replacement);
2108  rewriter.replaceOp(op, replacement);
2109  }
2110 };
2111 
2112 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2113  MLIRContext *context) {
2114  results.add<
2115       OpWithOffsetSizesAndStridesConstantArgumentFolder<
2116           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2117  ExtractSliceOpCastFolder>(context);
2118 }
2119 
2120 // Returns success if `op` is a no-op slice of `shapedType`: zero offsets,
2121 // unit strides, and sizes matching the type's shape.
2121 static LogicalResult
2122 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2123  ShapedType shapedType) {
2124  OpBuilder b(op.getContext());
2125  for (OpFoldResult ofr : op.getMixedOffsets())
2126  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2127  return failure();
2128  // Rank-reducing noops only need to inspect the leading dimensions:
2129  // llvm::zip is appropriate.
2130  auto shape = shapedType.getShape();
2131  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2132  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2133  return failure();
2134  for (OpFoldResult ofr : op.getMixedStrides())
2135  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2136  return failure();
2137  return success();
2138 }
2139 
2140 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2141 /// slice, we can return the InsertSliceOp's source directly.
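/// For illustration, a sketch with hypothetical shapes:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// %1 = tensor.extract_slice %0[0, 0] [4, 4] [1, 1]
///     : tensor<8x8xf32> to tensor<4x4xf32>
/// ```
///
/// folds %1 to %src.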
2142 // TODO: This only checks the immediate producer; extend to go up the
2143 // insert/extract chain if the slices are disjoint.
2144 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2145  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2146 
2147  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2148  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2149  insertOp.isSameAs(extractOp, isSame))
2150  return insertOp.getSource();
2151 
2152  return {};
2153 }
2154 
2155 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2156  if (auto splat =
2157  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
2158  auto resultType = llvm::cast<ShapedType>(getResult().getType());
2159  if (resultType.hasStaticShape())
2160  return splat.resizeSplat(resultType);
2161  }
2162  if (getSourceType() == getType() &&
2163       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2164  return this->getSource();
2165  if (Value slice = foldExtractAfterInsertSlice(*this))
2166  return slice;
2167 
2168  return OpFoldResult();
2169 }
2170 
2171 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2172  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2173  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2174  unsigned rank = rankedTensorType.getRank();
2175  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2176  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2177  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2178  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2179  offsets, sizes, strides);
2180 }
2181 
2182 //===----------------------------------------------------------------------===//
2183 // InsertSliceOp
2184 //===----------------------------------------------------------------------===//
2185 
2186 void InsertSliceOp::getAsmResultNames(
2187  function_ref<void(Value, StringRef)> setNameFn) {
2188  setNameFn(getResult(), "inserted_slice");
2189 }
2190 
2191 // Build an InsertSliceOp with mixed static and dynamic entries.
2192 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2193  Value dest, ArrayRef<OpFoldResult> offsets,
2194  ArrayRef<OpFoldResult> sizes,
2195  ArrayRef<OpFoldResult> strides,
2196  ArrayRef<NamedAttribute> attrs) {
2197  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2198  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2199  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2200  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2201  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2202  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2203  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2204  b.getDenseI64ArrayAttr(staticSizes),
2205  b.getDenseI64ArrayAttr(staticStrides));
2206  result.addAttributes(attrs);
2207 }
2208 
2209 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2210 /// Range vector.
2211 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2212  Value dest, ArrayRef<Range> ranges,
2213  ArrayRef<NamedAttribute> attrs) {
2214  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2215  build(b, result, source, dest, offsets, sizes, strides, attrs);
2216 }
2217 
2218 // Build an InsertSliceOp with dynamic entries.
2219 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2220  Value dest, ValueRange offsets, ValueRange sizes,
2221  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2222  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2223  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2224  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2225  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2226  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2227  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2228  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2229 }
2230 
2231 /// Rank-reducing type verification for both InsertSliceOp and
2232 /// ParallelInsertSliceOp.
2233 static SliceVerificationResult verifyInsertSliceOp(
2234  RankedTensorType srcType, RankedTensorType dstType,
2235  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2236  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2237  // insert_slice is the inverse of extract_slice, use the same type
2238  // inference.
2239  RankedTensorType expected = ExtractSliceOp::inferResultType(
2240  dstType, staticOffsets, staticSizes, staticStrides);
2241  if (expectedType)
2242  *expectedType = expected;
2243  return isRankReducedType(expected, srcType);
2244 }
2245 
2246 /// Verifier for InsertSliceOp.
2247 LogicalResult InsertSliceOp::verify() {
2248  RankedTensorType expectedType;
2249  SliceVerificationResult result =
2250  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2251  getStaticSizes(), getStaticStrides(), &expectedType);
2252  return produceSliceErrorMsg(result, *this, expectedType);
2253 }
2254 
2255 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2256 /// can mutate the second InsertSliceOp's destination to the first one's.
2257 ///
2258 /// Example:
2259 ///
2260 /// ```mlir
2261 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2262 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2263 /// ```
2264 ///
2265 /// folds into:
2266 ///
2267 /// ```mlir
2268 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2269 /// ```
2270 ///
2271 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2272 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2273  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2274 
2275  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2276  if (!prevInsertOp ||
2277  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2278  !prevInsertOp.isSameAs(insertOp, isSame))
2279  return failure();
2280 
2281  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2282  return success();
2283 }
2284 
2285 /// Folds round-trip extract/insert slice op pairs.
2286 /// Example:
2287 /// ```mlir
2288 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2289 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2290 /// ```
2291 /// can be folded into %val.
2292 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2293  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2294 
2295  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2296  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2297  !extractOp.isSameAs(insertOp, isSame))
2298  return nullptr;
2299 
2300  return extractOp.getSource();
2301 }
2302 
2303 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2304  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2305  getSourceType() == getType() &&
2306       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2307  return this->getSource();
2308   if (succeeded(foldInsertAfterInsertSlice(*this)))
2309  return getResult();
2310  if (auto result = foldInsertAfterExtractSlice(*this))
2311  return result;
2312  return OpFoldResult();
2313 }
2314 
2315 LogicalResult InsertSliceOp::reifyResultShapes(
2316  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2317  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2318  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2319  return success();
2320 }
2321 
2322 namespace {
2323 /// Pattern to rewrite an insert_slice op with constant arguments.
2324 ///
2325 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
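///
/// For illustration, a sketch with hypothetical shapes, where %c1 is
/// `arith.constant 1 : index`:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[%c1, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// ```
///
/// is rewritten with the offset folded to its static value:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[1, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// ```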
2326 template <typename InsertOpTy>
2327 class InsertSliceOpConstantArgumentFolder final
2328  : public OpRewritePattern<InsertOpTy> {
2329 public:
2330   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2331 
2332  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2333  PatternRewriter &rewriter) const override {
2334  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2335  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2336  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2337 
2338     // No constant operands were folded, just return.
2339  if (failed(foldDynamicIndexList(mixedOffsets)) &&
2340  failed(foldDynamicIndexList(mixedSizes)) &&
2341  failed(foldDynamicIndexList(mixedStrides)))
2342  return failure();
2343 
2344  // Create the new op in canonical form.
2345  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2346  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2347  mixedOffsets, mixedSizes, mixedStrides);
2348  Value toInsert = insertSliceOp.getSource();
2349  if (sourceType != insertSliceOp.getSourceType()) {
2350  OpBuilder::InsertionGuard g(rewriter);
2351  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2352  // is that the insertion point is just before the ParallelCombiningOp in
2353  // the parallel case.
2354  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2355  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2356  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2357  sourceType, toInsert);
2358  }
2359  rewriter.replaceOpWithNewOp<InsertOpTy>(
2360  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2361  mixedSizes, mixedStrides);
2362  return success();
2363  }
2364 };
2365 
2366 /// Fold tensor_casts with insert_slice operations. If the source or
2367 /// destination tensor is a tensor_cast that removes static type information,
2368 /// the cast is folded into the insert_slice operation. E.g.:
2369 ///
2370 /// ```mlir
2371 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2372 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2373 /// ```
2374 ///
2375 /// folds into:
2376 ///
2377 /// ```mlir
2378 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2379 /// ```
2380 ///
2381 /// Note: When folding a cast on the destination tensor, the result of the
2382 /// insert_slice operation is cast to ensure that the type of the result did
2383 /// not change.
2384 ///
2385 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2386 template <typename InsertOpTy>
2387 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2388   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2389 
2390  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2391  PatternRewriter &rewriter) const override {
2392  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2393  return matchPattern(operand, matchConstantIndex());
2394  }))
2395  return failure();
2396 
2397  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2398  auto castOp = v.getDefiningOp<tensor::CastOp>();
2399  if (!castOp || !canFoldIntoConsumerOp(castOp))
2400  return std::nullopt;
2401  return castOp.getSource();
2402  };
2403  std::optional<Value> sourceCastSource =
2404  getSourceOfCastOp(insertSliceOp.getSource());
2405  std::optional<Value> destCastSource =
2406  getSourceOfCastOp(insertSliceOp.getDest());
2407  if (!sourceCastSource && !destCastSource)
2408  return failure();
2409 
2410  auto src =
2411  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2412  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2413  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2414  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2415  if (!srcType || !dstType)
2416  return failure();
2417  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2418  insertSliceOp.getStaticSizes(),
2419  insertSliceOp.getStaticStrides()) !=
2420         SliceVerificationResult::Success)
2421  return failure();
2422 
2423  Operation *replacement = rewriter.create<InsertOpTy>(
2424  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2425  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2426 
2427  // In the parallel case there is no result and so nothing to cast.
2428  bool isParallelInsert =
2429  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2430  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2431  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2432  insertSliceOp.getDestType(),
2433  replacement->getResult(0));
2434  }
2435  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2436  return success();
2437  }
2438 };
2439 
2440 /// If additional static type information can be deduced from an insert_slice's
2441 /// size operands, insert an explicit cast of the op's source operand. This
2442 /// enables other canonicalization patterns that are matching for tensor_cast
2443 /// ops such as `ForOpTensorCastFolder` in SCF.
2444 ///
2445 /// Example:
2446 ///
2447 /// ```mlir
2448 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2449 /// : tensor<?x?xf32> into ...
2450 /// ```
2451 ///
2452 /// folds into:
2453 ///
2454 /// ```mlir
2455 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2456 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2457 /// : tensor<64x64xf32> into ...
2458 /// ```
2459 ///
2460 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2461 template <typename InsertOpTy>
2462 struct InsertSliceOpSourceCastInserter final
2463  : public OpRewritePattern<InsertOpTy> {
2464   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2465 
2466  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2467  PatternRewriter &rewriter) const override {
2468  RankedTensorType srcType = insertSliceOp.getSourceType();
2469  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2470  return failure();
2471  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
2472  srcType.getShape().end());
2473  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2474  if (std::optional<int64_t> constInt =
2475  getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
2476  newSrcShape[i] = *constInt;
2477  }
2478 
2479  RankedTensorType newSrcType =
2480  RankedTensorType::get(newSrcShape, srcType.getElementType());
2481  if (srcType == newSrcType ||
2482  !preservesStaticInformation(srcType, newSrcType) ||
2483  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2484  return failure();
2485 
2486  // newSrcType is:
2487  // 1) Different from srcType.
2488  // 2) "More static" than srcType.
2489  // 3) Cast-compatible with srcType.
2490  // Insert the cast.
2491  OpBuilder::InsertionGuard g(rewriter);
2492  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2493  // that the insertion point is just before the ParallelCombiningOp in the
2494  // parallel case.
2495  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2496  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2497  Value cast = rewriter.create<tensor::CastOp>(
2498  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2499  rewriter.replaceOpWithNewOp<InsertOpTy>(
2500  insertSliceOp, cast, insertSliceOp.getDest(),
2501  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2502  insertSliceOp.getMixedStrides());
2503  return success();
2504  }
2505 };
2506 } // namespace
2507 
2508 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2509  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2510 }
2511 
2512 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2513  MLIRContext *context) {
2514  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2515  InsertSliceOpCastFolder<InsertSliceOp>,
2516  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2517 }
2518 
2519 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2520  Location loc,
2521  Value tensor,
2522  Value dest) {
2523  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
2524  unsigned rank = rankedTensorType.getRank();
2525  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2526  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2527  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2528  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2529  sizes, strides);
2530 }
2531 
2532 //===----------------------------------------------------------------------===//
2533 // PadOp
2534 //===----------------------------------------------------------------------===//
2535 
2536 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2537  setNameFn(getResult(), "padded");
2538 }
2539 
2540 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2541 // supports optional types.
2542 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2543  Type typeToInfer, Type typeToInferFrom) {}
2544 
2545 static ParseResult
2546 parseInferType(OpAsmParser &parser,
2547  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2548  Type &typeToInfer, Type typeToInferFrom) {
2549  if (optOperand)
2550  typeToInfer = typeToInferFrom;
2551  return success();
2552 }
2553 
2554 LogicalResult PadOp::verify() {
2555  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
2556  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
2557  auto expectedType =
2558  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2559  if (!expectedType) {
2560  return emitError("failed to infer expectedType from sourceType ")
2561  << sourceType << ", specified resultType is " << resultType;
2562  }
2563  if (resultType.getRank() != expectedType.getRank()) {
2564  return emitError("specified type ")
2565  << resultType << " does not match the inferred type "
2566  << expectedType;
2567  }
2568  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2569  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2570  continue;
2571  if (expectedType.isDynamicDim(i))
2572  continue;
2573  return emitError("specified type ")
2574  << resultType << " does not match the inferred type "
2575  << expectedType;
2576  }
2577 
2578  return success();
2579 }
2580 
2581 LogicalResult PadOp::verifyRegions() {
2582  auto &region = getRegion();
2583  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2584  Block &block = region.front();
2585  if (block.getNumArguments() != rank)
2586  return emitError("expected the block to have ") << rank << " arguments";
2587 
2588  // Note: the number and type of yield values are checked in the YieldOp.
2589  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2590  if (!en.value().isIndex())
2591  return emitOpError("expected block argument ")
2592  << (en.index() + 1) << " to be an index";
2593  }
2594 
2595  // Ensure that the region yields an element of the right type.
2596  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2597  if (yieldOp.getValue().getType() !=
2598  llvm::cast<ShapedType>(getType()).getElementType())
2599  return emitOpError("expected yield type to match shape element type");
2600 
2601  return success();
2602 }
2603 
2604 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2605  ArrayRef<int64_t> staticLow,
2606  ArrayRef<int64_t> staticHigh,
2607  ArrayRef<int64_t> resultShape) {
2608  unsigned rank = sourceType.getRank();
2609  if (staticLow.size() != rank)
2610  return RankedTensorType();
2611  if (staticHigh.size() != rank)
2612  return RankedTensorType();
2613  if (!(resultShape.empty() || resultShape.size() == rank))
2614  return RankedTensorType();
2615 
2616  SmallVector<int64_t, 4> inferredShape;
2617  for (auto i : llvm::seq<unsigned>(0, rank)) {
2618  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2619  staticHigh[i] == ShapedType::kDynamic) {
2620  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2621  : resultShape[i]);
2622  } else {
2623  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2624  assert((resultShape.empty() || size == resultShape[i] ||
2625  resultShape[i] == ShapedType::kDynamic) &&
2626  "mismatch between inferred shape and result shape");
2627  inferredShape.push_back(size);
2628  }
2629  }
2630 
2631  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2632 }
2633 
2634 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2635  Value source, ArrayRef<int64_t> staticLow,
2636  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
2637  bool nofold, ArrayRef<NamedAttribute> attrs) {
2638  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2639  if (!resultType)
2640  resultType = inferResultType(sourceType, staticLow, staticHigh);
2641  build(b, result, resultType, source, low, high,
2642  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2643  nofold ? b.getUnitAttr() : UnitAttr());
2644  result.addAttributes(attrs);
2645 }
2646 
2647 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2648  Value source, ValueRange low, ValueRange high, bool nofold,
2649  ArrayRef<NamedAttribute> attrs) {
2650  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2651  unsigned rank = sourceType.getRank();
2652  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
2653  build(b, result, resultType, source, staticVector, staticVector, low, high,
2654  nofold, attrs);
2655 }
2656 
2657 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2658  Value source, ArrayRef<OpFoldResult> low,
2659  ArrayRef<OpFoldResult> high, bool nofold,
2660  ArrayRef<NamedAttribute> attrs) {
2661  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2662  SmallVector<Value, 4> dynamicLow, dynamicHigh;
2663  SmallVector<int64_t, 4> staticLow, staticHigh;
2664   // staticLow and staticHigh carry the full padding configuration. Each entry
2665   // grows staticLow and staticHigh by one value. If an entry is dynamic
2666   // (i.e. not a constant), dynamicLow and dynamicHigh grow by one value as
2667   // well.
2668  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
2669  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
2670  if (!resultType) {
2671  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
2672  }
2673  assert(llvm::isa<RankedTensorType>(resultType));
2674  build(b, result, resultType, source, dynamicLow, dynamicHigh,
2675  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2676  nofold ? b.getUnitAttr() : UnitAttr());
2677  result.addAttributes(attrs);
2678 }
2679 
2680 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2681  Value source, ArrayRef<OpFoldResult> low,
2682  ArrayRef<OpFoldResult> high, Value constantPadValue,
2683  bool nofold, ArrayRef<NamedAttribute> attrs) {
2684  build(b, result, resultType, source, low, high, nofold, attrs);
2685 
2686  // Add a region and a block to yield the pad value.
2687  Region *region = result.regions[0].get();
2688  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
2689  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
2690  SmallVector<Location> blockArgLocs(sourceRank, result.location);
2691 
2692  // `builder.createBlock` changes the insertion point within the block. Create
2693  // a guard to reset the insertion point of the builder after it is destroyed.
2694  OpBuilder::InsertionGuard guard(b);
2695  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
2696  b.create<tensor::YieldOp>(result.location, constantPadValue);
2697 }
2698 
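/// Return the set of dimensions that are padded, i.e. those whose low or high
/// padding amount is not a constant zero. E.g. (illustrative): low = [0, %x],
/// high = [1, 0] marks both dimensions 0 and 1 as padded.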
2699 llvm::SmallBitVector PadOp::getPaddedDims() {
2700  llvm::SmallBitVector paddedDims(getSourceType().getRank());
2701  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
2702  for (const auto &en : enumerate(paddingWidths))
2703  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
2704  paddedDims.set(en.index());
2705  };
2706  extractPaddedDims(getMixedLowPad());
2707  extractPaddedDims(getMixedHighPad());
2708  return paddedDims;
2709 }
2710 
2711 namespace {
2712 // Folds tensor.pad when the padding is all static zeros and the nofold
2713 // attribute doesn't request otherwise.
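//
// For illustration, a sketch with hypothetical shapes and pad value:
//
// ```mlir
// %0 = tensor.pad %t low[0, 0] high[0, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %pad_value : f32
// } : tensor<4x4xf32> to tensor<4x4xf32>
// ```
//
// is replaced by a tensor.cast of %t to the result type.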
2714 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
2715   using OpRewritePattern<PadOp>::OpRewritePattern;
2716 
2717  LogicalResult matchAndRewrite(PadOp padTensorOp,
2718  PatternRewriter &rewriter) const override {
2719  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
2720  return failure();
2721  if (padTensorOp.getNofold())
2722  return failure();
2723  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2724  padTensorOp, padTensorOp.getResult().getType(),
2725  padTensorOp.getSource());
2726  return success();
2727  }
2728 };
2729 
2730 // Fold CastOp into PadOp when adding static information.
2731 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
2732   using OpRewritePattern<PadOp>::OpRewritePattern;
2733 
2734  LogicalResult matchAndRewrite(PadOp padTensorOp,
2735  PatternRewriter &rewriter) const override {
2736  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
2737  if (!tensor::canFoldIntoConsumerOp(castOp))
2738  return failure();
2739 
2740  auto newResultType = PadOp::inferResultType(
2741  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
2742  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2743  padTensorOp.getResultType().getShape());
2744 
2745  if (newResultType == padTensorOp.getResultType()) {
2746  rewriter.updateRootInPlace(padTensorOp, [&]() {
2747  padTensorOp.getSourceMutable().assign(castOp.getSource());
2748  });
2749  } else {
2750  auto newOp = rewriter.create<PadOp>(
2751  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
2752  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2753  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
2754  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2755  IRMapping mapper;
2756  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
2757 
2758  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2759  padTensorOp, padTensorOp.getResultType(), newOp);
2760  }
2761  return success();
2762  }
2763 };
2764 
2765 // Fold CastOp using the result of PadOp back into the latter if it adds
2766 // static information.
2767 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2768   using OpRewritePattern<PadOp>::OpRewritePattern;
2769 
2770  LogicalResult matchAndRewrite(PadOp padTensorOp,
2771  PatternRewriter &rewriter) const override {
2772  if (!padTensorOp.getResult().hasOneUse())
2773  return failure();
2774  auto tensorCastOp =
2775  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2776  if (!tensorCastOp)
2777  return failure();
2778  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
2779  tensorCastOp.getDest().getType()))
2780  return failure();
2781 
2782  auto replacementOp = rewriter.create<PadOp>(
2783  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2784  padTensorOp.getSource(), padTensorOp.getStaticLow(),
2785  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
2786  padTensorOp.getHigh(), padTensorOp.getNofold(),
2787  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2788  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2789 
2790  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
2791  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
2792  return success();
2793  }
2794 };
2795 
2796 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2797 /// different dimensions. The pattern applies if the following preconditions
2798 /// hold:
2799 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
2800 /// 2) the tensor::ExtractSliceOps have only unit-strides,
2801 /// 3) the tensor::PadOps perform only high-padding,
2802 /// 4) the tensor::PadOps have the same constant padding value,
2803 /// 5) the tensor::PadOps do not have common padding dimensions,
2804 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2805 /// zero-offset for every dimension.
2806 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
2807 ///    padded source dimensions.
2809 ///
2810 /// Example:
2811 ///
2812 /// ```mlir
2813 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2814 /// : tensor<64x64xf32> to tensor<?x64xf32>
2815 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2816 /// } : tensor<?x64xf32> to tensor<8x64xf32>
2817 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2818 /// : tensor<8x64xf32> to tensor<8x?xf32>
2819 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2820 /// } : tensor<8x?xf32> to tensor<8x4xf32>
2821 /// ```
2822 ///
2823 /// folds into:
2824 ///
2825 /// ```mlir
2826 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2827 /// : tensor<64x64xf32> to tensor<?x?xf32>
2828 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2829 /// } : tensor<?x?xf32> to tensor<8x4xf32>
2830 /// ```
2831 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
2832   using OpRewritePattern<PadOp>::OpRewritePattern;
2833 
2834  LogicalResult matchAndRewrite(PadOp padOp,
2835  PatternRewriter &rewriter) const override {
2836  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
2837  if (!innerSliceOp)
2838  return failure();
2839  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
2840  if (!outerPadOp || outerPadOp.getNofold())
2841  return failure();
2842  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
2843  if (!outerSliceOp)
2844  return failure();
2845 
2846  // 1) Fail if the chain is rank-reducing.
2847  int64_t rank = padOp.getSourceType().getRank();
2848  if (outerSliceOp.getSourceType().getRank() != rank) {
2849  return rewriter.notifyMatchFailure(padOp,
2850  "cannot fold rank-reducing chain");
2851  }
2852 
2853  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2854  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
2855  return rewriter.notifyMatchFailure(
2856  padOp, "cannot fold non-unit stride ExtractSliceOps");
2857  }
2858 
2859  // 3) Fail if the tensor::PadOps have non-zero low padding.
2860  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
2861  return rewriter.notifyMatchFailure(padOp,
2862  "cannot fold PadOps with low padding");
2863  }
2864 
2865  // 4) Fail if the tensor::PadOps padding values do not match.
2866  Attribute innerAttr, outerAttr;
2867  Value innerValue = padOp.getConstantPaddingValue();
2868  Value outerValue = outerPadOp.getConstantPaddingValue();
2869  if (!innerValue || !outerValue ||
2870  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
2871  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
2872  innerAttr != outerAttr) {
2873  return rewriter.notifyMatchFailure(
2874  padOp, "cannot fold PadOps with different padding values");
2875  }
2876 
2877  // 5) Fail if a dimension is padded by both tensor::PadOps.
2878  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
2879  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
2880  if (innerDims.anyCommon(outerDims)) {
2881  return rewriter.notifyMatchFailure(
2882  padOp, "cannot fold PadOps with common padding dimensions");
2883  }
2884 
2885  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2886  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2887     // for every dimension, and use the offset of the other pair. Fail if no
2888  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2889  // exists.
2890  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
2891  for (auto en : enumerate(newOffsets)) {
2892  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
2893  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
2894  if (!innerDims.test(en.index()) &&
2895  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
2896  en.value() = outerOffset;
2897  continue;
2898  }
2899  if (!outerDims.test(en.index()) &&
2900  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
2901  en.value() = innerOffset;
2902  continue;
2903  }
2904  return rewriter.notifyMatchFailure(
2905  padOp, "cannot find zero-offset and zero-padding pair");
2906  }
2907 
2908  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
2909  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
2910  // outer tensor::PadOp and fail if the size of the inner
2911  // tensor::ExtractSliceOp does not match the size of the padded dimension.
2912  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
2913  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
2914  for (auto en : enumerate(newSizes)) {
2915  if (!outerDims.test(en.index()))
2916  continue;
2917  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
2918  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
2919  assert(!ShapedType::isDynamic(sourceSize) &&
2920  "expected padded dimension to have a static size");
2921  if (getConstantIntValue(sliceSize) != sourceSize) {
2922  return rewriter.notifyMatchFailure(
2923  padOp, "cannot fold since the inner ExtractSliceOp size does not "
2924  "match the size of the outer padding");
2925  }
2926  en.value() = outerSliceOp.getMixedSizes()[en.index()];
2927  }
2928 
2929  // Combine the high paddings of the two tensor::PadOps.
2930  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
2931  for (auto en : enumerate(newHighPad)) {
2932  if (innerDims.test(en.index()))
2933  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
2934  if (outerDims.test(en.index()))
2935  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
2936  }
2937 
2938  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
2939  // the two paddings in one step.
2940  auto newSliceOp = rewriter.create<ExtractSliceOp>(
2941  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
2942  innerSliceOp.getMixedStrides());
2943  auto newPadOp = rewriter.create<PadOp>(
2944  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
2945  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
2946  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
2947  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
2948  newPadOp.getRegion().begin());
2949  rewriter.replaceOp(padOp, newPadOp.getResult());
2950  return success();
2951  }
2952 };
2953 
2954 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
2955   using OpRewritePattern<PadOp>::OpRewritePattern;
2955  using OpRewritePattern<PadOp>::OpRewritePattern;
2956 
2957  LogicalResult matchAndRewrite(PadOp padTensorOp,
2958  PatternRewriter &rewriter) const override {
2959  Value input = padTensorOp.getSource();
2960  if (!llvm::isa<RankedTensorType>(input.getType()))
2961  return failure();
2962  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
2963  auto inputRank = inputDims.size();
2964 
2965  auto oldResultType =
2966  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
2967  if (!oldResultType)
2968  return failure();
2969 
2970  auto outputDims = oldResultType.getShape();
2971 
2972  // Extract the static info from the high and low operands.
2973  SmallVector<int64_t> constOperandsLow;
2974  for (auto operand : padTensorOp.getLow()) {
2975  APSInt intOp;
2976  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
2977  constOperandsLow.push_back(ShapedType::kDynamic);
2978  continue;
2979  }
2980  constOperandsLow.push_back(intOp.getExtValue());
2981  }
2982  SmallVector<int64_t> constOperandsHigh;
2983  for (auto operand : padTensorOp.getHigh()) {
2984  APSInt intOp;
2985  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
2986  constOperandsHigh.push_back(ShapedType::kDynamic);
2987  continue;
2988  }
2989  constOperandsHigh.push_back(intOp.getExtValue());
2990  }
2991 
2992  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
2993  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
2994 
2995  // Verify the op is well-formed.
2996  if (inputDims.size() != outputDims.size() ||
2997  inputDims.size() != constLow.size() ||
2998  inputDims.size() != constHigh.size())
2999  return failure();
3000 
3001  auto lowCount = 0;
3002  auto highCount = 0;
3003  for (size_t i = 0; i < inputRank; i++) {
3004  if (constLow[i] == ShapedType::kDynamic)
3005  constLow[i] = constOperandsLow[lowCount++];
3006  if (constHigh[i] == ShapedType::kDynamic)
3007  constHigh[i] = constOperandsHigh[highCount++];
3008  }
3009 
3010  auto staticLow = ArrayRef<int64_t>(constLow);
3011  auto staticHigh = ArrayRef<int64_t>(constHigh);
3012 
3013  // Calculate the output sizes with the static information.
3014  SmallVector<int64_t> newOutDims;
3015  for (size_t i = 0; i < inputRank; i++) {
3016  if (outputDims[i] == ShapedType::kDynamic) {
3017  newOutDims.push_back(
3018  (staticLow[i] == ShapedType::kDynamic ||
3019  staticHigh[i] == ShapedType::kDynamic ||
3020  inputDims[i] == ShapedType::kDynamic
3021  ? ShapedType::kDynamic
3022  : inputDims[i] + staticLow[i] + staticHigh[i]));
3023  } else {
3024  newOutDims.push_back(outputDims[i]);
3025  }
3026  }
3027 
3028  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3029  llvm::all_of(newOutDims,
3030  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3031  return failure();
3032 
3033  // Rewrite the op using the new static type.
3034  auto newResultType = RankedTensorType::get(
3035  newOutDims, padTensorOp.getType().getElementType());
3036  auto newOp = rewriter.create<PadOp>(
3037  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3038  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3039  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3040 
3041  IRMapping mapper;
3042  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3043  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3044  newOp);
3045 
3046  return success();
3047  }
3048 };
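// Schematic example (SSA names illustrative): when the source shape is static
// and the low/high padding amounts are constants, FoldStaticPadding infers a
// static result type and casts back to the original type. Roughly:
//
// ```mlir
// %c2 = arith.constant 2 : index
// %0 = tensor.pad %t low[0, %c2] high[0, %c2] { ... }
//     : tensor<4x4xf32> to tensor<?x?xf32>
// ```
//
// becomes:
//
// ```mlir
// %0 = tensor.pad %t low[0, 2] high[0, 2] { ... }
//     : tensor<4x4xf32> to tensor<4x8xf32>
// %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
// ```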
3049 
3050 } // namespace
3051 
3052 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3053  MLIRContext *context) {
3054  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3055  FoldOrthogonalPaddings, FoldStaticPadding>(context);
3056 }
3057 
3058 /// Return the padding value of the PadOp if it is constant. In this context,
3059 /// "constant" means an actual constant or "defined outside of the block".
3060 ///
3061 /// Values are considered constant in three cases:
3062 /// - A ConstantLike value.
3063 /// - A basic block argument from a different block.
3064 /// - A value defined outside of the block.
3065 ///
3066 /// If the padding value is not constant, an empty Value is returned.
3067 Value PadOp::getConstantPaddingValue() {
3068  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3069  if (!yieldOp)
3070  return {};
3071  Value padValue = yieldOp.getValue();
3072  // Check if yield value is a constant.
3073  if (matchPattern(padValue, m_Constant()))
3074  return padValue;
3075  // Check if yield value is defined inside the PadOp block.
3076  if (padValue.getParentBlock() == &getRegion().front())
3077  return {};
3078  // Else: Yield value defined outside of the PadOp block.
3079  return padValue;
3080 }
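// Schematic example (SSA names illustrative): %cst below is a constant defined
// outside of the pad body, so getConstantPaddingValue() returns it:
//
// ```mlir
// %cst = arith.constant 0.0 : f32
// %0 = tensor.pad %t low[1, 1] high[1, 1] {
// ^bb0(%i: index, %j: index):
//   tensor.yield %cst : f32
// } : tensor<2x2xf32> to tensor<4x4xf32>
// ```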
3081 
3082 OpFoldResult PadOp::fold(FoldAdaptor) {
3083  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3084  !getNofold())
3085  return getSource();
3086  return {};
3087 }
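// Schematic example (SSA names illustrative): a pad whose static result type
// equals its source type, i.e. all padding amounts are zero, folds to its
// source unless `nofold` is set:
//
// ```mlir
// %0 = tensor.pad %t low[0, 0] high[0, 0] { ... }
//     : tensor<4x4xf32> to tensor<4x4xf32>  // folds to %t
// ```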
3088 
3089 //===----------------------------------------------------------------------===//
3090 // ParallelInsertSliceOp
3091 //===----------------------------------------------------------------------===//
3092 
3093 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3094  ParallelCombiningOpInterface parallelCombiningParent =
3095  getParallelCombiningParent();
3096  for (const auto &it :
3097  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3098  Operation &nextOp = it.value();
3099  if (&nextOp == getOperation())
3100  return parallelCombiningParent.getParentResult(it.index());
3101  }
3102  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3103 }
3104 
3105 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3106 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3107  Value source, Value dest,
3108  ArrayRef<OpFoldResult> offsets,
3109  ArrayRef<OpFoldResult> sizes,
3110  ArrayRef<OpFoldResult> strides,
3111  ArrayRef<NamedAttribute> attrs) {
3112  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3113  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3114  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3115  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3116  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3117  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3118  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3119  b.getDenseI64ArrayAttr(staticSizes),
3120  b.getDenseI64ArrayAttr(staticStrides));
3121  result.addAttributes(attrs);
3122 }
3123 
3124 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3125 /// packed into a Range vector.
3126 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3127  Value source, Value dest,
3128  ArrayRef<Range> ranges,
3129  ArrayRef<NamedAttribute> attrs) {
3130  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3131  build(b, result, source, dest, offsets, sizes, strides, attrs);
3132 }
3133 
3134 // Build a ParallelInsertSliceOp with dynamic entries.
3135 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3136  Value source, Value dest, ValueRange offsets,
3137  ValueRange sizes, ValueRange strides,
3138  ArrayRef<NamedAttribute> attrs) {
3139  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3140  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3141  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3142  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3143  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3144  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3145  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3146 }
3147 
3148 LogicalResult ParallelInsertSliceOp::verify() {
3149  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3150  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3151  << *(getOperation()->getParentOp());
3152 
3153  RankedTensorType expectedType;
3154  SliceVerificationResult result =
3155  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3156  getStaticSizes(), getStaticStrides(), &expectedType);
3157  return produceSliceErrorMsg(result, *this, expectedType);
3158 }
3159 
3160 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3161  RewritePatternSet &results, MLIRContext *context) {
3162  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3163  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3164  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3165 }
3166 
3167 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3168  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3169 }
3170 
3171 //===----------------------------------------------------------------------===//
3172 // ScatterOp
3173 //===----------------------------------------------------------------------===//
3174 
3175 void ScatterOp::getAsmResultNames(
3176  function_ref<void(Value, StringRef)> setNameFn) {
3177  setNameFn(getResult(), "scatter");
3178 }
3179 
3180 LogicalResult ScatterOp::verify() {
3181  int64_t destRank = getDestType().getRank();
3182  ArrayRef<int64_t> scatterDims = getScatterDims();
3183  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3184  "scatter", "dest")))
3185  return failure();
3186 
3187  if (!getUnique())
3188  return emitOpError("requires 'unique' attribute to be set");
3189  // TODO: we could also check statically that there are fewer leading index
3190  // tensor dims than the dest dims. If this is not the case, the unique
3191  // attribute cannot be true.
3192 
3193  // Use the GatherOp::inferResultType on the `dest` type and verify the
3194  // expected type matches the source type.
3195  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3196  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3197  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3198  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3199  if (getSourceType() != expectedSourceType &&
3200  getSourceType() != expectedRankReducedSourceType) {
3201  return emitOpError("source type "
3202  "mismatch: "
3203  "expected ")
3204  << expectedSourceType << " or its rank-reduced variant "
3205  << expectedRankReducedSourceType << " (got: " << getSourceType()
3206  << ")";
3207  }
3208 
3209  return success();
3210 }
3211 
3212 //===----------------------------------------------------------------------===//
3213 // SplatOp
3214 //===----------------------------------------------------------------------===//
3215 
3216 void SplatOp::getAsmResultNames(
3217  function_ref<void(Value, StringRef)> setNameFn) {
3218  setNameFn(getResult(), "splat");
3219 }
3220 
3221 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3222  auto constOperand = adaptor.getInput();
3223  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
3224  return {};
3225 
3226  // SplatElementsAttr::get treats a single value for the second arg as being
3227  // a splat.
3228  return SplatElementsAttr::get(getType(), {constOperand});
3229 }
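// Schematic example (SSA names illustrative): splatting a constant scalar
// folds to a splat elements attribute:
//
// ```mlir
// %cst = arith.constant 1.0 : f32
// %0 = tensor.splat %cst : tensor<2x2xf32>  // folds to dense<1.0> : tensor<2x2xf32>
// ```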
3230 
3231 //===----------------------------------------------------------------------===//
3232 // PackOp/UnPackOp Common
3233 //===----------------------------------------------------------------------===//
3234 
3235 namespace {
3236 
3237 /// Packing a one-dimensional tensor can be expressed as an expand shape op.
3238 struct SimplifyPackToExandShape : public OpRewritePattern<PackOp> {
3239  using OpRewritePattern<PackOp>::OpRewritePattern;
3240 
3241  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
3242  Type newOperandType, ArrayAttr reassociation) const {
3243  if (operand.getType() == newOperandType)
3244  return operand;
3245  return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
3246  reassociation);
3247  }
3248 
3249  LogicalResult matchAndRewrite(PackOp packOp,
3250  PatternRewriter &rewriter) const override {
3251  RankedTensorType sourceType = packOp.getSourceType();
3252  RankedTensorType destType = packOp.getDestType();
3253  if (sourceType.getRank() != 1 || packOp.getPaddingValue())
3254  return failure();
3255  auto reassociation =
3256  getReassociationIndicesForReshape(sourceType, destType);
3257  if (!reassociation)
3258  return failure();
3259  Value expanded = insertExpand(
3260  rewriter, packOp.getLoc(), packOp.getSource(), destType,
3261  getReassociationIndicesAttribute(rewriter, *reassociation));
3262  rewriter.replaceOp(packOp, expanded);
3263  return success();
3264  }
3265 };
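// Schematic example (SSA names illustrative): packing a one-dimensional tensor
// without a padding value is just a reshape, so
//
// ```mlir
// %0 = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8] into %dest
//     : tensor<64xf32> -> tensor<8x8xf32>
// ```
//
// is rewritten to:
//
// ```mlir
// %0 = tensor.expand_shape %src [[0, 1]] : tensor<64xf32> into tensor<8x8xf32>
// ```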
3266 
3267 } // namespace
3268 
3270  patterns.add<SimplifyPackToExandShape>(patterns.getContext());
3271 }
3272 
3273 template <typename OpTy>
3274 static LogicalResult
3275 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3276  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3277  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3278  "applies to only pack or unpack operations");
3279  int64_t destRank = op.getDestRank();
3280  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3281  reifiedReturnShapes[0] =
3282  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3283  return success();
3284 }
3285 
3286 template <typename OpTy>
3287 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3288  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3289  "applies to only pack or unpack operations");
3290  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3291  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3292  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3293  assert(tiles.size() == dimsToTile.size() &&
3294  "tiles must match indices of dimension to block");
3295  // bind the dimension `i` with the tile factor.
3296  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3297  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3298  return dimAndTileMapping;
3299 }
3300 
3301 template <typename OpTy>
3302 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3303  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3304  "applies to only pack or unpack operations");
3305  Builder builder(op);
3306  SmallVector<OpFoldResult> mixedInnerTiles;
3307  unsigned dynamicValIndex = 0;
3308  for (int64_t staticTile : op.getStaticInnerTiles()) {
3309  if (!ShapedType::isDynamic(staticTile))
3310  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3311  else
3312  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3313  }
3314  return mixedInnerTiles;
3315 }
3316 
3317 template <typename OpTy>
3318 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3319  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3320  "applies to only pack or unpack operations");
3321  SmallVector<Value> dynamicTiles;
3322  SmallVector<int64_t> staticTiles;
3323  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3324  return staticTiles;
3325 }
3326 
3327 /// Returns true if `dimsPos` is invalid. It is invalid when:
3328 /// a) It contains duplicates.
3329 /// b) At least one dimension is out of bounds (a valid `dimPos` is >= 0 and < `rank`).
3330 /// c) The number of elements in `dimsPos` is greater than `rank`.
3331 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3332  size_t rank) {
3333  size_t dimsPosSize = dimsPos.size();
3334  if (dimsPosSize > rank)
3335  return true;
3336  DenseSet<int64_t> uniqued;
3337  for (int64_t dim : dimsPos)
3338  uniqued.insert(dim);
3339  if (dimsPosSize != uniqued.size())
3340  return true;
3341  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3342  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3343  });
3344 }
3345 
3346 /// Returns true if each dimension of `sourceShape` is smaller than or equal to
3347 /// the corresponding dimension of `limitShape` (dynamic dimensions are accepted).
3348 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3349  ArrayRef<int64_t> limitShape) {
3350  assert(
3351  sourceShape.size() == limitShape.size() &&
3352  "expected source shape rank, and limit of the shape to have same rank");
3353  return llvm::all_of(
3354  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3355  int64_t sourceExtent = std::get<0>(it);
3356  int64_t limit = std::get<1>(it);
3357  return ShapedType::isDynamic(sourceExtent) ||
3358  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3359  });
3360 }
3361 
3362 template <typename OpTy>
3363 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3364  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3365  "applies to only pack or unpack operations");
3366  Operation *op = packOrUnPack.getOperation();
3367 
3368  // Return true if we have a zero-value tile.
3369  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3370  return llvm::any_of(
3371  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3372  };
3373 
3374  // Verify tiles. Do not allow zero tiles.
3375  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3376  if (hasZeros(mixedTiles))
3377  return op->emitError("invalid zero tile factor");
3378 
3379  // Verify inner_dims_pos and outer_dims_perm.
3380  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3381  ? packOrUnPack.getSourceType()
3382  : packOrUnPack.getDestType();
3383  size_t unpackedRank = unpackedType.getRank();
3384  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3385  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3386  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3387  return op->emitError("invalid inner_dims_pos vector");
3388  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3389  return op->emitError("invalid outer_dims_perm vector");
3390  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3391  return op->emitError("outer_dims_perm must be a permutation or empty");
3392 
3393  // Tiling factors must be less than or equal to the input rank for pack (or
3394  // output rank for unpack), and must match the number of `inner_dims_pos`.
3395  if (mixedTiles.size() > unpackedRank) {
3396  return op->emitError("tiling factors must be less than or equal to the "
3397  "input rank for pack or output rank for unpack");
3398  }
3399  if (mixedTiles.size() != innerDimsPos.size()) {
3400  return op->emitError(
3401  "tiling factors must equal the number of dimensions to tile");
3402  }
3403 
3404  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3405  ? packOrUnPack.getDestType()
3406  : packOrUnPack.getSourceType();
3407  size_t packedRank = packedType.getRank();
3408  // Require output rank to match input rank + number of blocking factors.
3409  if (unpackedRank + mixedTiles.size() != packedRank) {
3410  return op->emitError(
3411  "packed rank must equal unpacked rank + tiling factors");
3412  }
3413 
3414  // Verify that the result shape is at least as large as the minimum shape
3415  // expected by the pack operation, and that the output shape represents
3416  // full tiles.
3417  RankedTensorType expectedPackedType = PackOp::inferPackedType(
3418  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3419  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3420  return op->emitError("the shape of output is not large enough to hold the "
3421  "packed data. Expected at least ")
3422  << expectedPackedType << ", got " << packedType;
3423  }
3424  if (!llvm::all_of(
3425  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3426  mixedTiles),
3427  [](std::tuple<int64_t, OpFoldResult> it) {
3428  std::optional<int64_t> constTileSize =
3429  getConstantIntValue(std::get<1>(it));
3430  int64_t shape = std::get<0>(it);
3431  if (!constTileSize) {
3432  // If specified tile size is dynamic, output shape should
3433  // be dynamic too.
3434  return ShapedType::isDynamic(shape);
3435  }
3436  if (ShapedType::isDynamic(shape)) {
3437  // For the shape being dynamic when tile size is
3438  // specified, return true. In canonical form a constant
3439  // tile size should lead to constant shape of the tiled
3440  // dimension, but not needed for verification.
3441  return true;
3442  }
3443  return shape == constTileSize.value();
3444  })) {
3445  return op->emitError("mismatch in inner tile sizes specified and shape of "
3446  "the tiled dimension in the packed type");
3447  }
3448  return success();
3449 }
3450 
3451 namespace {
3452 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
3453 /// various permutations to the op.
3454 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
3455 // these. These may or may not become true foldings / canonicalizations
3456 // depending on how aggressive we want to be in automatically folding
3457 // transposes.
3458 struct PackOrUnPackTransposeResult {
3459  SmallVector<int64_t> innerDimsPos;
3460  SmallVector<OpFoldResult> innerTiles;
3461  SmallVector<int64_t> outerDimsPerm;
3462 };
3463 } // namespace
3464 
3465 template <typename OpTy>
3466 static PackOrUnPackTransposeResult
3467 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
3468  ArrayRef<int64_t> innerPermutation,
3469  ArrayRef<int64_t> outerPermutation) {
3470  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3471  "applies to only pack or unpack operations");
3472  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3473  "some permutation must be non-empty");
3474  PackOrUnPackTransposeResult metadata;
3475  metadata.innerDimsPos =
3476  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
3477  metadata.innerTiles =
3478  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
3479  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3480  ? packOrUnPackOp.getSourceRank()
3481  : packOrUnPackOp.getDestRank();
3482  metadata.outerDimsPerm =
3483  packOrUnPackOp.getOuterDimsPerm().empty()
3484  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3485  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
3486  if (!innerPermutation.empty()) {
3487  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3488  isPermutationVector(innerPermutation) &&
3489  "invalid inner permutation");
3490  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
3491  applyPermutationToVector(metadata.innerTiles, innerPermutation);
3492  }
3493  if (!outerPermutation.empty()) {
3494  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3495  isPermutationVector(outerPermutation) &&
3496  "invalid outer permutation");
3497  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
3498  }
3499  return metadata;
3500 }
3501 
3502 //===----------------------------------------------------------------------===//
3503 // PackOp
3504 //===----------------------------------------------------------------------===//
3505 
3506 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3507  setNameFn(getResult(), "pack");
3508 }
3509 
3510 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3511  Value dest, ArrayRef<int64_t> innerDimsPos,
3512  ArrayRef<OpFoldResult> innerTiles,
3513  std::optional<Value> paddingValue,
3514  ArrayRef<int64_t> outerDimsPerm) {
3515  assert(innerDimsPos.size() == innerTiles.size() &&
3516  "number of tile sizes specified must match the specified number of "
3517  "original dimensions to be tiled");
3518  SmallVector<int64_t> staticTileSizes;
3519  SmallVector<Value> dynamicTileSizes;
3520  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3521  build(builder, state, dest.getType(), source, dest,
3522  paddingValue ? *paddingValue : nullptr,
3523  outerDimsPerm.empty() ? nullptr
3524  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3525  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3526  builder.getDenseI64ArrayAttr(staticTileSizes));
3527 }
3528 
3529 LogicalResult
3530 PackOp::reifyResultShapes(OpBuilder &builder,
3531  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3532  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3533 }
3534 
3535 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3536  return getDimAndTileMappingImpl(*this);
3537 }
3538 
3539 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3540  return getMixedTilesImpl(*this);
3541 }
3542 
3543 SmallVector<int64_t> PackOp::getStaticTiles() {
3544  return getStaticTilesImpl(*this);
3545 }
3546 
3547 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
3548  ArrayRef<int64_t> innerDimsPos,
3549  ArrayRef<OpFoldResult> innerTiles) {
3550  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3551  if (ShapedType::isDynamic(inputShape[pos]))
3552  continue;
3553  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3554  if (!constantTile)
3555  continue;
3556  if (inputShape[pos] % (*constantTile) != 0)
3557  return true;
3558  }
3559  return false;
3560 }
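// Schematic example (SSA names illustrative): tile size 8 does not divide
// dimension 65, so requirePaddingValue returns true and the pack must carry a
// padding_value:
//
// ```mlir
// %0 = tensor.pack %src padding_value(%cst : f32)
//     inner_dims_pos = [0] inner_tiles = [8] into %dest
//     : tensor<65xf32> -> tensor<9x8xf32>
// ```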
3561 
3562 LogicalResult PackOp::verify() {
3563  if (failed(commonVerifierPackAndUnPackOp(*this)))
3564  return failure();
3565 
3566  // Verify padding value, and bail out if the tile does not divide the
3567  // dimension fully. In the case of dynamic tile factors or dimensions, having
3568  // a partial tile is undefined behavior.
3569  auto paddingValue = getPaddingValue();
3570  if (paddingValue &&
3571  paddingValue.getType() != getSourceType().getElementType()) {
3572  return emitOpError("expected padding_value has ")
3573  << getSourceType().getElementType()
3574  << " but got: " << paddingValue.getType();
3575  }
3576 
3577  if (!paddingValue &&
3578  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3579  getMixedTiles())) {
3580  return emitOpError("invalid tile factor provided. Only full tiles are "
3581  "supported when padding_value is not set");
3582  }
3583  return success();
3584 }
3585 
3586 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
3587 /// Values to kDynamic, even if they are arith.constant values.
3588 static SmallVector<int64_t>
3589 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
3590  SmallVector<int64_t> result;
3591  for (auto o : ofrs) {
3592  // Have to do this first, as getConstantIntValue special-cases constants.
3593  if (llvm::dyn_cast_if_present<Value>(o))
3594  result.push_back(ShapedType::kDynamic);
3595  else
3596  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
3597  }
3598  return result;
3599 }
3600 
3601 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
3602 /// the packed type. A shared helper ensures that the two methods agree on
3603 /// which dimensions are dynamic.
3604 static SmallVector<int64_t> getPackOpResultTypeShape(
3605  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
3606  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
3607  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
3608  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3609  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3610  continue;
3611  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3612  resultShape[tiledDim.value()] = ShapedType::kDynamic;
3613  continue;
3614  }
3615  resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
3616  innerTileSizes[tiledDim.index()]);
3617  }
3618 
3619  // Swap tile loops if outer_dims_perm is available.
3620  if (!outerDimsPerm.empty())
3621  applyPermutationToVector(resultShape, outerDimsPerm);
3622 
3623  // Append the inner tile dimensions.
3624  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
3625  return resultShape;
3626 }
3627 
3628 SmallVector<OpFoldResult> PackOp::getResultShape(
3629  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
3630  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
3631  ArrayRef<int64_t> outerDimsPerm) {
3632  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
3633 
3634  AffineExpr s0, s1;
3635  bindSymbols(builder.getContext(), s0, s1);
3636  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
3637  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3638  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
3639  builder, loc, ceilDivExpr,
3640  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
3641  }
3642  if (!outerDimsPerm.empty())
3643  applyPermutationToVector(resultDims, outerDimsPerm);
3644  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
3645 
3646  SmallVector<int64_t> resultTypeShape =
3647  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
3648  asShapeWithAnyValueAsDynamic(innerTileSizes),
3649  innerDimsPos, outerDimsPerm);
3650 
3651  // Fix-up `resultDims` to ensure that they are Value's if and only if the
3652  // result type shape says it's a dynamic dim. This is needed as callers may
3653  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
3654  // dynamic dims returned by that.
3655  for (unsigned i = 0; i < resultDims.size(); ++i) {
3656  if (!ShapedType::isDynamic(resultTypeShape[i]))
3657  continue;
3658  resultDims[i] =
3659  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
3660  }
3661 
3662  return resultDims;
3663 }
3664 
3665 /// Get the expected packed type based on source type, tile factors, position of
3666 /// the inner tiles and permutation of the outer tiled loop.
3667 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
3668  ArrayRef<int64_t> innerTileSizes,
3669  ArrayRef<int64_t> innerDimsPos,
3670  ArrayRef<int64_t> outerDimsPerm) {
3671  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
3672  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
3673  return RankedTensorType::get(resultShape, sourceType.getElementType());
3674 }
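// Worked example (shapes illustrative): packing dim 0 of a tensor<16x8xf32>
// with a tile size of 4 yields ceil(16 / 4) = 4 outer tiles, so
//
//   inferPackedType(tensor<16x8xf32>, /*innerTileSizes=*/{4},
//                   /*innerDimsPos=*/{0}, /*outerDimsPerm=*/{})
//
// returns tensor<4x8x4xf32>: the tiled outer dims followed by the appended
// inner tile dims.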
3675 
3676 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
3677  ArrayRef<OpFoldResult> innerTileSizes,
3678  ArrayRef<int64_t> innerDimsPos,
3679  ArrayRef<int64_t> outerDimsPerm) {
3680  AffineExpr dim0, dim1;
3681  bindDims(b.getContext(), dim0, dim1);
3682  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
3683  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
3684  {v1, v2});
3685  };
3686 
3687  SmallVector<OpFoldResult> mixedSizes;
3688  for (auto [index, value] : llvm::enumerate(
3689  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
3690  if (ShapedType::isDynamic(value))
3691  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
3692  else
3693  mixedSizes.push_back(b.getIndexAttr(value));
3694  }
3695  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
3696  int64_t dimPos = std::get<0>(it);
3697  OpFoldResult tileSize = std::get<1>(it);
3698  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
3699  }
3700  if (!outerDimsPerm.empty())
3701  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
3702 
3703  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
3704  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3705  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
3706 }
3707 
3708 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
3709  ArrayRef<int64_t> innerPermutation,
3710  ArrayRef<int64_t> outerPermutation) {
3711  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
3712  *this, innerPermutation, outerPermutation);
3713  Value transposedDest =
3714  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
3715  metadata.innerDimsPos, metadata.outerDimsPerm);
3716  return b.create<PackOp>(loc, getSource(), transposedDest,
3717  metadata.innerDimsPos, metadata.innerTiles,
3718  getPaddingValue(), metadata.outerDimsPerm);
3719 }
3720 
3721 /// Returns true if the tiles and the tiled dims are constant.
3722 template <typename OpTy>
3723 bool areTilesAndTiledDimsAllConstant(OpTy op) {
3724  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3725  "applies to only pack or unpack operations");
3726  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3727  ? op.getDestType()
3728  : op.getSourceType();
3729  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
3730  for (auto [dimDest, tile] : llvm::zip(
3731  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
3732  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
3733  if (!constTileSize || ShapedType::isDynamic(dimDest))
3734  return false;
3735  }
3736  return true;
3737 }
3738 
3739 Speculation::Speculatability PackOp::getSpeculatability() {
3740  if (getPaddingValue())
3741  return Speculation::Speculatable;
3742 
3743  // The verifier already rejects operations where we can statically prove that
3744  // the tile sizes do not evenly divide the dimension; thus, only check that
3745  // the tiles and the tiled inner dimensions are constant.
3746  if (!areTilesAndTiledDimsAllConstant(*this))
3747  return Speculation::NotSpeculatable;
3748 
3749  return Speculation::Speculatable;
3750 }
3751 
3752 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
3753 // dimensions for pack and unpack.
3754 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
3755  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3756  return false;
3757  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
3758 }
3759 
3760 // Return true if pack and unpack have the same tiles.
3761 // Same SSA values or same integer constants.
3762 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
3763  auto packTiles = packOp.getMixedTiles();
3764  auto unPackTiles = unPackOp.getMixedTiles();
3765  if (packTiles.size() != unPackTiles.size())
3766  return false;
3767  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
3768  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
3769  return false;
3770  }
3771  return true;
3772 }
3773 
3774 /// Fold a pack(unpack(x)) to x.
3775 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
3776  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
3777  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
3778  return failure();
3779  if (packOp.getPaddingValue() ||
3780  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
3781  !haveSameTiles(packOp, unPackOp))
3782  return failure();
3783  rewriter.replaceOp(packOp, unPackOp.getSource());
3784  return success();
3785 }
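// Schematic example (SSA names illustrative): a pack that exactly reverses an
// unpack (same tiles, same inner/outer dim attributes, no padding value) folds
// back to the original packed tensor:
//
// ```mlir
// %1 = tensor.unpack %0 inner_dims_pos = [0] inner_tiles = [8] into %d0
//     : tensor<8x8xf32> -> tensor<64xf32>
// %2 = tensor.pack %1 inner_dims_pos = [0] inner_tiles = [8] into %d1
//     : tensor<64xf32> -> tensor<8x8xf32>  // %2 is replaced with %0
// ```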
3786 
3787 template <typename PackOrUnpackOp>
3788 static bool isLikePadUnPad(PackOrUnpackOp packOp,
3789  RankedTensorType packedTensorType) {
3790  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
3791  std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
3792  "Function meant for pack/unpack");
3793  // This is a pad if packing only adds ones and we don't transpose dimensions.
3794 
3795  // Check that we are not transposing any dimensions.
3796  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
3797  int64_t numPackedDims = innerDimsPos.size();
3798  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
3799  if (orderedDims != innerDimsPos) {
3800  // Dimensions don't happen in order.
3801  return false;
3802  }
3803 
3804  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
3805  int64_t packedRank = packedTensorType.getRank();
3806  // At this point we know that we are taking numPackedDims outer
3807  // dimensions and pushing them all the way as the inner most dimensions.
3808  // What's left on the outer most dimensions is, in this order:
3809  // - the factor of the packed dimensions, then
3810  // - the untouched dimensions
3811  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
3812  // if all the dimensions that bubble outward are ones.
3813  // Therefore check that all the dimensions but the numPackedDims inner most
3814  // ones are ones.
3815  return llvm::all_of(
3816  llvm::seq<int64_t>(0, packedRank - numPackedDims),
3817  [&packedShape](int64_t i) { return packedShape[i] == 1; });
3818 }
3819 
3820 bool PackOp::isLikePad() {
3821  auto packedTensorType =
3822  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
3823  return isLikePadUnPad(*this, packedTensorType);
3824 }
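// Schematic example (SSA names illustrative): when inner_dims_pos is the
// identity and every remaining outer dimension of the packed type is 1, the
// pack behaves like a pad of the source up to the tile sizes:
//
// ```mlir
// %0 = tensor.pack %src padding_value(%cst : f32)
//     inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %dest
//     : tensor<13x15xf32> -> tensor<1x1x16x32xf32>  // isLikePad() == true
// ```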
3825 
3826 //===----------------------------------------------------------------------===//
3827 // UnPackOp
3828 //===----------------------------------------------------------------------===//
3829 
3830 void UnPackOp::getAsmResultNames(
3831  function_ref<void(Value, StringRef)> setNameFn) {
3832  setNameFn(getResult(), "unpack");
3833 }
3834 
3835 LogicalResult
3836 UnPackOp::reifyResultShapes(OpBuilder &builder,
3837  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3838  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3839 }
3840 
3841 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
3842  return getDimAndTileMappingImpl(*this);
3843 }
3844 
3845 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
3846  return getMixedTilesImpl(*this);
3847 }
3848 
3849 SmallVector<int64_t> UnPackOp::getStaticTiles() {
3850  return getStaticTilesImpl(*this);
3851 }
3852 
3853 LogicalResult UnPackOp::verify() {
3854  return commonVerifierPackAndUnPackOp(*this);
3855 }
3856 
3857 Speculation::Speculatability UnPackOp::getSpeculatability() {
3858  // See PackOp::getSpeculatability.
3859  if (!areTilesAndTiledDimsAllConstant(*this))
3860  return Speculation::NotSpeculatable;
3861 
3862  return Speculation::Speculatable;
3863 }
3864 
3865 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
3866  Value dest, ArrayRef<int64_t> innerDimsPos,
3867  ArrayRef<OpFoldResult> innerTiles,
3868  ArrayRef<int64_t> outerDimsPerm) {
3869  assert(innerDimsPos.size() == innerTiles.size() &&
3870  "number of tile sizes specified must match the specified number of "
3871  "original dimensions to be tiled");
3872  SmallVector<int64_t> staticTileSizes;
3873  SmallVector<Value> dynamicTileSizes;
3874  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3875  build(builder, state, dest.getType(), source, dest,
3876  outerDimsPerm.empty() ? nullptr
3877  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3878  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3879  builder.getDenseI64ArrayAttr(staticTileSizes));
3880 }
3881 
3882 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
3883  Value source,
3884  ArrayRef<OpFoldResult> innerTileSizes,
3885  ArrayRef<int64_t> innerDimsPos,
3886  ArrayRef<int64_t> outerDimsPerm) {
3887  AffineExpr sym0, sym1;
3888  bindSymbols(b.getContext(), sym0, sym1);
3889  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
3890  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
3891  };
3892 
3893  SmallVector<OpFoldResult> mixedSizes;
3894  auto srcType = llvm::cast<RankedTensorType>(source.getType());
3895  for (auto i :
3896  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
3897  if (srcType.isDynamicDim(i))
3898  mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
3899  else
3900  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
3901  }
3902  if (!outerDimsPerm.empty()) {
3903  applyPermutationToVector<OpFoldResult>(
3904  mixedSizes, invertPermutationVector(outerDimsPerm));
3905  }
3906 
3907  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
3908  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
3909 
3910  auto elemType = srcType.getElementType();
3911  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
3912 }
3913 
3914 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
3915  Value transposedSource,
3916  ArrayRef<int64_t> innerPermutation,
3917  ArrayRef<int64_t> outerPermutation) {
3918  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
3919  *this, innerPermutation, outerPermutation);
3920  return b.create<UnPackOp>(loc, transposedSource, getDest(),
3921  metadata.innerDimsPos, metadata.innerTiles,
3922  metadata.outerDimsPerm);
3923 }
3924 
3925 /// unpack(pack(x)) -> x
3926 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
3927  PatternRewriter &rewriter) {
3928  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
3929  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
3930  return failure();
3931  if (packOp.getPaddingValue() ||
3932  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
3933  !haveSameTiles(packOp, unPackOp))
3934  return failure();
3935  rewriter.replaceOp(unPackOp, packOp.getSource());
3936  return success();
3937 }
3938 
3939 bool UnPackOp::isLikeUnPad() {
3940  RankedTensorType packedTensorType = getSourceType();
3941  return isLikePadUnPad(*this, packedTensorType);
3942 }
3943 //===----------------------------------------------------------------------===//
3944 // Common Canonicalizers and Folders.
3945 //===----------------------------------------------------------------------===//
3946 
3947 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
3948 /// the `tensor.cast` has source that is more static than the consuming op.
3949 ///
3950 /// Example:
3951 /// ```mlir
3952 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3953 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
3954 /// ```
3955 ///
3956 /// folds into:
3957 ///
3958 /// ```mlir
3959 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
3960 /// ```
3961 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
3962 /// can add the pattern to their canonicalizers.
3963 struct FoldTensorCastProducerOp
3964  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
3965  using OpInterfaceRewritePattern<
3966  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
3967 
3968  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
3969  PatternRewriter &rewriter) const override {
3970  // InsertSliceOp has its own logic about folding tensor.cast ops.
3971  if (isa<InsertSliceOp>(op.getOperation()))
3972  return failure();
3973 
3974  // Exclude DPS ops that are also LoopLike from this interface as they
3975  // might need special handling of attached regions.
3976  if (isa<LoopLikeOpInterface>(op.getOperation()))
3977  return failure();
3978 
3979  // If no operand comes from a tensor::CastOp that can be folded, fail.
3980  bool hasTensorCastOperand =
3981  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
3982  if (llvm::isa<BlockArgument>(opOperand.get()))
3983  return false;
3984  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
3985  return castOp && canFoldIntoConsumerOp(castOp);
3986  });
3987  if (!hasTensorCastOperand)
3988  return failure();
3989 
3990  SmallVector<Type, 4> newResultTypes;
3991  newResultTypes.reserve(op->getNumResults());
3992  SmallVector<Value, 4> newOperands;
3993  newOperands.reserve(op->getNumOperands());
3994  for (OpOperand &opOperand : op->getOpOperands()) {
3995  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
3996  bool fold = canFoldIntoConsumerOp(tensorCastOp);
3997  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
3998  if (op.isDpsInit(&opOperand) &&
3999  !llvm::isa<MemRefType>(newOperands.back().getType()))
4000  newResultTypes.push_back(newOperands.back().getType());
4001  }
4002 
4003  // Clone op.
4004  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
4005  SmallVector<Value, 4> replacements;
4006  replacements.reserve(newOp->getNumResults());
4007  for (auto [oldResult, newResult] :
4008  llvm::zip(op->getResults(), newOp->getResults())) {
4009  if (newResult.getType() != oldResult.getType()) {
4010  replacements.push_back(rewriter.create<tensor::CastOp>(
4011  op->getLoc(), oldResult.getType(), newResult));
4012  } else {
4013  replacements.push_back(newResult);
4014  }
4015  }
4016  rewriter.replaceOp(op, replacements);
4017 
4018  return success();
4019  }
4020 };
4021 
4022 //===----------------------------------------------------------------------===//
4023 // TensorDialect
4024 //===----------------------------------------------------------------------===//
4025 
4026 void TensorDialect::getCanonicalizationPatterns(
4027  RewritePatternSet &results) const {
4028  results.add<FoldTensorCastProducerOp>(getContext());
4029 }
4030 
4031 //===----------------------------------------------------------------------===//
4032 // TableGen'd op method definitions
4033 //===----------------------------------------------------------------------===//
4034 
4035 #define GET_OP_CLASSES
4036 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static void getDynamicSizes(RankedTensorType tp, const SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: TensorOps.cpp:3723
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition: TensorOps.cpp:343
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition: TensorOps.cpp:1043
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: TensorOps.cpp:3467
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
Definition: TensorOps.cpp:1849
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: TensorOps.cpp:3287
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: TensorOps.cpp:3318
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2546
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: TensorOps.cpp:3604
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
Definition: TensorOps.cpp:3589
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: TensorOps.cpp:3302
static void operandsAndShape(TensorType resultType, Operation::operand_range dynamicExtents, SmallVectorImpl< Value > &newOperands, SmallVectorImpl< int64_t > &newShape)
Extract operands and shape from a tensor with dynamic extents.
Definition: TensorOps.cpp:1144
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition: TensorOps.cpp:2272
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition: TensorOps.cpp:2122
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition: TensorOps.cpp:2144
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
Definition: TensorOps.cpp:3348
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1333
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:3762
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition: TensorOps.cpp:2233
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: TensorOps.cpp:3788
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: TensorOps.cpp:3363
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: TensorOps.cpp:3275
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: TensorOps.cpp:3331
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2542
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:130
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:3754
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition: TensorOps.cpp:2292
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition: TensorOps.cpp:1459
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:829
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:44
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:143
unsigned getNumArguments()
Definition: Block.h:121
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:202
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:357
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:150
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:528
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:505
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class represents an operand of an operation.
Definition: Value.h:261
This is a value defined by a result of an operation.
Definition: Value.h:448
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:460
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:248
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:259
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:90
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1276
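For illustration, a minimal sketch of the typical use: combining two OpFoldResult index quantities so that the result folds back to an attribute when both inputs are constant. The helper name addOpFoldResults is hypothetical, and the snippet assumes the affine dialect headers are available.
static OpFoldResult addOpFoldResults(OpBuilder &b, Location loc,
                                     OpFoldResult lhs, OpFoldResult rhs) {
  // Bind two dimension expressions and compose them into a single-result map.
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  auto sumMap = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
  // Folds to an IntegerAttr when both operands are constant; otherwise emits
  // an affine.apply and returns its result Value.
  return affine::makeComposedFoldedAffineApply(b, loc, sumMap, {lhs, rhs});
}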
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:114
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:350
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:314
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2082
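A minimal sketch of how such populate functions are typically consumed, e.g. from a pass body; foldConstantExtractSlices is a hypothetical name and the greedy driver header (mlir/Transforms/GreedyPatternRewriteDriver.h) is assumed to be included.
static void foldConstantExtractSlices(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  tensor::populateFoldConstantExtractSlicePatterns(patterns);
  // Apply the registered patterns greedily below `root`.
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}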
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:305
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:275
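As an illustration, a hypothetical pattern (not the in-tree canonicalization) that uses this query: when a tensor.extract reads through a tensor.cast that merely erases static shape information, the extract can read from the cast's source directly.
struct FoldCastIntoExtract : public OpRewritePattern<ExtractOp> {
  using OpRewritePattern<ExtractOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto castOp = op.getTensor().getDefiningOp<CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return rewriter.notifyMatchFailure(op, "no foldable producing cast");
    // The cast only relaxes static dimensions, so the element read is unchanged.
    rewriter.replaceOpWithNewOp<ExtractOp>(op, op.getType(), castOp.getSource(),
                                           op.getIndices());
    return success();
  }
};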
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] that inserts tensor into dest.
Definition: TensorOps.cpp:2519
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] that reduces tensor to targetType.
Definition: TensorOps.cpp:2171
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:118
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:50
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:69
void populateSimplifyTensorPack(RewritePatternSet &patterns)
Patterns to simplify tensor.pack.
Definition: TensorOps.cpp:3269
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:154
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:227
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:60
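For illustration, a minimal sketch (hypothetical helper, relying on this file's using-directives): build an extract_slice that covers the whole tensor, keeping static sizes as attributes and materializing tensor.dim ops only for dynamic ones.
static Value extractFullSlice(OpBuilder &b, Location loc, Value source) {
  int64_t rank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, source);
  return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes, strides);
}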
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:104
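A minimal sketch of a typical caller (hypothetical helper name): collect one destination tensor per tensor result of op, creating tensor.empty ops when the op does not already carry destinations.
static FailureOr<SmallVector<Value>> collectDestinations(OpBuilder &b,
                                                         Operation *op) {
  SmallVector<Value> destinations;
  if (failed(tensor::getOrCreateDestinations(b, op->getLoc(), op, destinations)))
    return failure();
  return destinations;
}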
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
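For illustration, the common idiom combining matchPattern with the m_ConstantInt binder above (hypothetical helper name):
static std::optional<int64_t> getConstantIntBehind(Value v) {
  llvm::APInt constant;
  if (matchPattern(v, m_ConstantInt(&constant)))
    return constant.getSExtValue();
  return std::nullopt;
}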
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:398
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
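A minimal sketch using the OpFoldResult helpers above (hypothetical helper name): a mixed static/dynamic stride list is "unit" when every entry folds to the constant 1.
static bool hasAllUnitStrides(ArrayRef<OpFoldResult> strides) {
  return llvm::all_of(strides, [](OpFoldResult stride) {
    return isConstantIntValue(stride, 1);
  });
}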
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:331
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
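For illustration (assumed values): the rounding-up division used when computing tile counts.
static int64_t numTilesExample() {
  // 10 elements with a tile size of 4 require ceilDiv(10, 4) == 3 tiles.
  return ceilDiv(10, 4);
}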
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:345
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:40
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs)
Returns "success" when any of the elements in ofrs is a constant value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
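For illustration (assumed shapes): reducing 4x1x8x1 to 4x8 drops the unit dimensions at positions 1 and 3.
static void rankReductionMaskExample() {
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask(/*originalShape=*/{4, 1, 8, 1},
                               /*reducedShape=*/{4, 8});
  assert(mask && mask->contains(1) && mask->contains(3));
}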
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:825
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:28
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method that returns the inverse of a permutation.
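For illustration (assumed values), a minimal sketch of the permutation utilities above:
static void permutationExample() {
  SmallVector<int64_t> perm = {2, 0, 1};
  assert(isPermutationVector(perm));
  // The inverse of [2, 0, 1] is [1, 2, 0].
  SmallVector<int64_t> inverse = invertPermutationVector(perm);
  // Applying [2, 0, 1] reorders {10, 20, 30} into {30, 10, 20}.
  SmallVector<int64_t> values = {10, 20, 30};
  applyPermutationToVector(values, perm);
  (void)inverse;
}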
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
Definition: TensorOps.cpp:3964
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:3968
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2101
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2102
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2089
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2090
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:372
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.