TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <algorithm>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::tensor;
37 
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
40 using llvm::mod;
41 
42 /// Materialize a single constant operation from a given attribute value with
43 /// the desired resultant type.
44 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
45                                               Attribute value, Type type,
46                                               Location loc) {
47  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
48  return op;
49  if (complex::ConstantOp::isBuildableWith(value, type))
50  return builder.create<complex::ConstantOp>(loc, type,
51  llvm::cast<ArrayAttr>(value));
52  return nullptr;
53 }
54 
55 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
56                                   int64_t dim) {
57  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
59  if (tensorType.isDynamicDim(dim))
60  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
61 
62  return builder.getIndexAttr(tensorType.getDimSize(dim));
63 }
64 
65 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
66                                                 Location loc, Value value) {
67   auto tensorType = llvm::cast<RankedTensorType>(value.getType());
68   SmallVector<OpFoldResult> result;
69  for (int64_t i = 0; i < tensorType.getRank(); ++i)
70  result.push_back(getMixedSize(builder, loc, value, i));
71  return result;
72 }
73 
74 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
75                                                 OpResult opResult) {
76  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
77  assert(tensorType && "expected tensor type");
78 
79  // If the op has a destination, it implements DestinationStyleOpInterface and
80  // we can query the destination operand from that interface.
81  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
82  if (destOp)
83  return destOp.getTiedOpOperand(opResult)->get();
84 
85   // Otherwise, create a new destination tensor with the same shape.
86   OpBuilder::InsertionGuard g(b);
87   b.setInsertionPoint(opResult.getDefiningOp());
88 
89  // Compute sizes.
90  SmallVector<OpFoldResult> mixedSizes;
91  if (!tensorType.hasStaticShape()) {
92  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
93  ReifiedRankedShapedTypeDims reifiedShapes;
94  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
95  return failure();
96  mixedSizes = reifiedShapes[opResult.getResultNumber()];
97  } else {
98  // Static shape: Take static sizes directly.
99  for (int64_t sz : tensorType.getShape())
100  mixedSizes.push_back(b.getIndexAttr(sz));
101  }
102 
103  // Create empty tensor.
104  Value emptyTensor =
105  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
106  return emptyTensor;
107 }
108 
109 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
110                                               Operation *op,
111  SmallVector<Value> &result) {
112  for (OpResult opResult : op->getResults()) {
113  if (llvm::isa<TensorType>(opResult.getType())) {
114  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
115  if (failed(destination))
116  return failure();
117  result.push_back(*destination);
118  }
119  }
120  return success();
121 }
122 
123 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
124   if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
125  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
126  return rtp1.getShape() == rtp2.getShape() &&
127  rtp1.getElementType() == rtp2.getElementType();
128  return false;
129  }
130  return tp1 == tp2; // default implementation
131 }
132 
133 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
134 /// rank-extending tensor.insert_slice op.
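/// For illustration (example values chosen by the editor, not taken from the
/// original source): with reducedShape = [4, 8] and mixedSizes = [1, 4, 1, 8],
/// the static unit sizes at positions 0 and 2 have no counterpart in the
/// reduced shape, so the returned bit vector has bits 0 and 2 set.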
135 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
136  ArrayRef<OpFoldResult> mixedSizes) {
137  llvm::SmallBitVector droppedDims(mixedSizes.size());
138  int64_t shapePos = 0;
139 
140  for (const auto &size : enumerate(mixedSizes)) {
141  // Rank-reduced dims must have a static unit dimension.
142  bool isStaticUnitSize =
143  size.value().is<Attribute>() &&
144  llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
145 
146  if (shapePos == static_cast<int64_t>(reducedShape.size())) {
147  // There are no more dims in the reduced shape. All remaining sizes must
148  // be rank-reduced dims.
149  assert(isStaticUnitSize && "expected unit dim");
150  droppedDims.set(size.index());
151  continue;
152  }
153 
154  // Dim is preserved if the size is not a static 1.
155  if (!isStaticUnitSize) {
156  ++shapePos;
157  continue;
158  }
159 
160  // Dim is preserved if the reduced shape dim is also 1.
161  if (reducedShape[shapePos] == 1) {
162  ++shapePos;
163  continue;
164  }
165 
166  // Otherwise: Dim is dropped.
167  droppedDims.set(size.index());
168  }
169 
170  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
171  "dimension mismatch");
172  return droppedDims;
173 }
174 
175 /// Given a ranked tensor type and a range of values that defines its dynamic
176 /// dimension sizes, turn all dynamic sizes that have a constant value into
177 /// static dimension sizes.
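/// Illustrative example (hypothetical values): for type tensor<?x?xf32> with
/// dynamicSizes = [%c5, %n], where %c5 is an arith.constant 5 and %n is
/// unknown, the function returns tensor<5x?xf32> and appends only %n to
/// foldedDynamicSizes.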
178 static RankedTensorType
179 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
180  SmallVector<Value> &foldedDynamicSizes) {
181  SmallVector<int64_t> staticShape(type.getShape().begin(),
182  type.getShape().end());
183  assert(type.getNumDynamicDims() ==
184  static_cast<int64_t>(dynamicSizes.size()) &&
185  "incorrect number of dynamic sizes");
186 
187  // Compute new static and dynamic sizes.
188  unsigned ctr = 0;
189  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
190  if (type.isDynamicDim(i)) {
191  Value dynamicSize = dynamicSizes[ctr++];
192  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
193  if (cst.has_value()) {
194  // Dynamic size must be non-negative.
195  if (cst.value() < 0) {
196  foldedDynamicSizes.push_back(dynamicSize);
197  continue;
198  }
199  staticShape[i] = *cst;
200  } else {
201  foldedDynamicSizes.push_back(dynamicSize);
202  }
203  }
204  }
205 
206  return RankedTensorType::get(staticShape, type.getElementType(),
207  type.getEncoding());
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // BitcastOp
212 //===----------------------------------------------------------------------===//
213 
214 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
215  if (inputs.size() != 1 || outputs.size() != 1)
216  return false;
217  Type a = inputs.front(), b = outputs.front();
218  auto aT = dyn_cast<TensorType>(a);
219  auto bT = dyn_cast<TensorType>(b);
220  if (!aT || !bT)
221  return false;
222 
223  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
224  return false;
225 
226  return succeeded(verifyCompatibleShape(aT, bT));
227 }
228 
229 namespace {
230 
231 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
232 /// operation.
233 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
234   using OpRewritePattern<BitcastOp>::OpRewritePattern;
235 
236  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
237  PatternRewriter &rewriter) const final {
238  auto tensorBitcastOperand =
239  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
240  if (!tensorBitcastOperand)
241  return failure();
242 
243  auto resultType = cast<TensorType>(tensorBitcast.getType());
244  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
245  tensorBitcastOperand.getOperand());
246  return success();
247  }
248 };
249 
250 } // namespace
251 
252 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
253  MLIRContext *context) {
254  results.add<ChainedTensorBitcast>(context);
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // CastOp
259 //===----------------------------------------------------------------------===//
260 
261 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
262  setNameFn(getResult(), "cast");
263 }
264 
265 /// Returns true if `target` is a ranked tensor type that preserves static
266 /// information available in the `source` ranked tensor type.
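/// Illustrative examples (not from the original source):
///   preservesStaticInformation(tensor<8x?xf32>, tensor<8x16xf32>) == true
///   preservesStaticInformation(tensor<8x16xf32>, tensor<8x?xf32>) == false
/// because a dimension that is static in the source may not become dynamic in
/// the target.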
267 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
268   auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
269  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
270 
271  // Requires RankedTensorType.
272  if (!sourceType || !targetType)
273  return false;
274 
275  // Requires same elemental type.
276  if (sourceType.getElementType() != targetType.getElementType())
277  return false;
278 
279  // Requires same rank.
280  if (sourceType.getRank() != targetType.getRank())
281  return false;
282 
283  // Requires same encoding.
284  if (sourceType.getEncoding() != targetType.getEncoding())
285  return false;
286 
287  // If cast is towards more static sizes along any dimension, don't fold.
288  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
289  if (!ShapedType::isDynamic(std::get<0>(t)) &&
290  ShapedType::isDynamic(std::get<1>(t)))
291  return false;
292  }
293 
294  return true;
295 }
296 
297 /// Determines whether tensor::CastOp casts to a more dynamic version of the
298 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
299 /// implement canonicalization patterns for ops in different dialects that may
300 /// consume the results of tensor.cast operations. Such foldable tensor.cast
301 /// operations are typically inserted as `slice` ops and are canonicalized,
302 /// to preserve the type compatibility of their uses.
303 ///
304 /// Returns true when all conditions are met:
305 /// 1. source and result are ranked tensors with the same element type and rank.
306 /// 2. the cast source type has more static information than the result type.
307 ///
308 /// Example:
309 /// ```mlir
310 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
311 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
312 /// ```
313 ///
314 /// folds into:
315 ///
316 /// ```mlir
317 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
318 /// ```
319 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
320   if (!castOp)
321  return false;
322 
323  // Can fold if the source of cast has at least as much static information as
324  // its results.
325  return preservesStaticInformation(castOp.getType(),
326  castOp.getSource().getType());
327 }
328 
329 /// Determines whether the tensor::CastOp casts to a more static version of the
330 /// source tensor. This is useful to fold into a producing op and implement
331 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
332 /// being from different dialects. Returns true when all conditions are met:
333 /// 1. source and result are ranked tensors with the same element type and rank.
334 /// 2. the result type has more static information than the source.
335 ///
336 /// Example:
337 /// ```mlir
338 /// %1 = producer ... : tensor<?x?xf32>
339 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
340 /// ```
341 ///
342 /// can be canonicalized to :
343 ///
344 /// ```mlir
345 /// %2 = producer ... : tensor<8x16xf32>
346 /// ```
347 /// Not all ops might be canonicalizable this way, but for those that can be,
348 /// this method provides a check that it is worth doing the canonicalization.
349 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
350   if (!castOp)
351  return false;
352  return preservesStaticInformation(castOp.getSource().getType(),
353  castOp.getType());
354 }
355 
356 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
357 /// that can be folded.
358 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
359   bool folded = false;
360  for (OpOperand &operand : op->getOpOperands()) {
361  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
362  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
363  operand.set(castOp.getOperand());
364  folded = true;
365  }
366  }
367  return success(folded);
368 }
369 
370 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
371  if (inputs.size() != 1 || outputs.size() != 1)
372  return false;
373  Type a = inputs.front(), b = outputs.front();
374  auto aT = llvm::dyn_cast<TensorType>(a);
375  auto bT = llvm::dyn_cast<TensorType>(b);
376  if (!aT || !bT)
377  return false;
378 
379  if (aT.getElementType() != bT.getElementType())
380  return false;
381 
382  return succeeded(verifyCompatibleShape(aT, bT));
383 }
384 
385 /// Compute a TensorType that has the joined shape knowledge of the two
386 /// given TensorTypes. The element types need to match.
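/// Illustrative examples (not from the original source):
///   joinShapes(tensor<?x16xf32>, tensor<8x?xf32>)  -> tensor<8x16xf32>
///   joinShapes(tensor<8x16xf32>, tensor<4x16xf32>) -> null (conflicting dims)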
387 static TensorType joinShapes(TensorType one, TensorType two) {
388   assert(one.getElementType() == two.getElementType());
389 
390  if (!one.hasRank())
391  return two;
392  if (!two.hasRank())
393  return one;
394 
395  int64_t rank = one.getRank();
396  if (rank != two.getRank())
397  return {};
398 
399   SmallVector<int64_t, 4> join;
400   join.reserve(rank);
401  for (int64_t i = 0; i < rank; ++i) {
402  if (one.isDynamicDim(i)) {
403  join.push_back(two.getDimSize(i));
404  continue;
405  }
406  if (two.isDynamicDim(i)) {
407  join.push_back(one.getDimSize(i));
408  continue;
409  }
410  if (one.getDimSize(i) != two.getDimSize(i))
411  return {};
412  join.push_back(one.getDimSize(i));
413  }
414  return RankedTensorType::get(join, one.getElementType());
415 }
416 
417 namespace {
418 
419 /// Replaces chains of two tensor.cast operations by a single tensor.cast
420 /// operation if doing so does not remove runtime constraints.
421 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
422   using OpRewritePattern<CastOp>::OpRewritePattern;
423 
424  LogicalResult matchAndRewrite(CastOp tensorCast,
425  PatternRewriter &rewriter) const final {
426  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
427 
428  if (!tensorCastOperand)
429  return failure();
430 
431  auto sourceType =
432  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
433  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
434  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
435 
436  // We can remove the intermediate cast if joining all three produces the
437  // same result as just joining the source and result shapes.
438  auto firstJoin =
439  joinShapes(joinShapes(sourceType, intermediateType), resultType);
440 
441  // The join might not exist if the cast sequence would fail at runtime.
442  if (!firstJoin)
443  return failure();
444 
445  // The newJoin always exists if the above join exists; it might just contain
446  // less information. If so, we cannot drop the intermediate cast, as doing
447  // so would remove runtime checks.
448  auto newJoin = joinShapes(sourceType, resultType);
449  if (firstJoin != newJoin)
450  return failure();
451 
452  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
453  tensorCastOperand.getOperand());
454  return success();
455  }
456 };
457 
458 /// Fold tensor.cast into tensor.extract_slice producer.
459 /// Example:
460 /// ```
461 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
462 /// tensor<128x512xf32> to tensor<?x512xf32>
463 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
464 /// ```
465 /// ->
466 /// ```
467 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
468 /// tensor<128x512xf32> to tensor<16x512xf32>
469 /// ```
470 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
471   using OpRewritePattern<CastOp>::OpRewritePattern;
472 
473  LogicalResult matchAndRewrite(CastOp tensorCast,
474  PatternRewriter &rewriter) const final {
475  auto extractOperand =
476  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
477 
478  // Cannot fold cast to unranked tensor.
479  auto rankedResultType =
480  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
481  if (!rankedResultType)
482  return failure();
483 
484  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
485  rankedResultType.getShape() ==
486  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
487  .getShape())
488  return failure();
489 
490  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
491  auto dimMask = computeRankReductionMask(
492  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
493  size_t dimIndex = 0;
494  for (size_t i = 0, e = sizes.size(); i < e; i++) {
495  if (dimMask && dimMask->count(i))
496  continue;
497  int64_t dim = rankedResultType.getShape()[dimIndex++];
498  if (ShapedType::isDynamic(dim))
499  continue;
500  sizes[i] = rewriter.getIndexAttr(dim);
501  }
502 
503  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
504  tensorCast, rankedResultType, extractOperand.getSource(),
505  extractOperand.getMixedOffsets(), sizes,
506  extractOperand.getMixedStrides());
507  return success();
508  }
509 };
510 
511 } // namespace
512 
513 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
514  MLIRContext *context) {
515  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // ConcatOp
520 //===----------------------------------------------------------------------===//
521 
522 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
523  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
524  auto tensorTypes =
525  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
526  return llvm::cast<RankedTensorType>(type);
527  }));
528  int64_t concatRank = tensorTypes[0].getRank();
529 
530  // The concatenation dim must be in the range [0, rank).
531  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
532 
533  SmallVector<int64_t> sizes(concatRank);
534  for (int64_t i = 0, e = concatRank; i < e; ++i) {
535  if (i == dim)
536  continue;
537  SaturatedInteger size;
538  for (auto tensorType : tensorTypes)
539  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
540  sizes[i] = size.asInteger();
541  }
542  auto concatSize = SaturatedInteger::wrap(0);
543  for (auto tensorType : tensorTypes)
544  concatSize =
545  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
546  sizes[dim] = concatSize.asInteger();
547  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
548 }
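// Illustrative example (not from the original source): for dim = 0 and input
// types {tensor<3x?xf32>, tensor<5x4xf32>}, inferResultType returns
// tensor<8x4xf32>: the concatenated dimension is the sum of the input sizes,
// and every other dimension takes any static size the inputs agree on.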
549 
550 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
551  ValueRange inputs) {
552  FailureOr<RankedTensorType> resultType =
553  inferResultType(dim, inputs.getTypes());
554  assert(succeeded(resultType) && "failed to infer concatenation result type");
555  build(builder, result, *resultType, dim, inputs);
556 }
557 
558 LogicalResult ConcatOp::verify() {
559   if (getInputs().size() < 1)
560  return emitOpError("requires at least one input");
561 
562   SmallVector<RankedTensorType> inputTypes;
563   for (auto input : getInputs())
564  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
565 
566  RankedTensorType resultType = getResultType();
567  int64_t resultRank = getRank();
568  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
569  return type.getRank() != resultRank;
570  }))
571  return emitOpError("rank of concatenated inputs must match result rank");
572 
573  Type resultElementType = resultType.getElementType();
574  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
575  return type.getElementType() != resultElementType;
576  }))
577  return emitOpError("inputs and result element type must match");
578 
579  int64_t dim = getDim();
580  if (dim >= resultRank)
581  return emitOpError("concatenation dim must be less than the tensor rank");
582 
583  SmallVector<int64_t> sizes(resultRank);
584  for (int64_t i = 0, e = resultRank; i < e; ++i) {
585  if (i == dim)
586  continue;
587  SaturatedInteger size;
588  for (auto tensorType : inputTypes) {
589  FailureOr<SaturatedInteger> maybeSize =
590  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
591  if (failed(maybeSize))
592  return emitOpError("static concatenation size mismatch along ")
593  << "non-concatenated dimension " << i;
594  size = *maybeSize;
595  }
596  sizes[i] = size.asInteger();
597  }
598  auto concatSize = SaturatedInteger::wrap(0);
599  for (auto tensorType : inputTypes)
600  concatSize =
601  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
602  sizes[dim] = concatSize.asInteger();
603  auto inferredResultType =
604  RankedTensorType::get(sizes, inputTypes[0].getElementType());
605 
606  for (auto [inferredSize, actualSize] :
607  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
608  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
609  ShapedType::isDynamic(actualSize);
610  if (!hasDynamic && inferredSize != actualSize)
611  return emitOpError("result type ")
612  << resultType << " does not match inferred shape "
613  << inferredResultType << " static sizes";
614  }
615 
616  return success();
617 }
618 
619 LogicalResult
620 ConcatOp::reifyResultShapes(OpBuilder &builder,
621                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
622  ValueRange inputs = getInputs();
623  int64_t dim = getDim();
624  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
625 
626  Value init = inputs[0];
627  int64_t rank = getType().getRank();
628 
629  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
630 
631  // Pre-populate the result sizes with as much static information as possible
632  // from the given result type, as well as the inferred result type, otherwise
633  // use the dim sizes from the first input.
634  for (int64_t i = 0; i < rank; ++i) {
635  if (i == dim)
636  continue;
637  if (!getType().isDynamicDim(i)) {
638  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
639  } else if (!inferredResultType.isDynamicDim(i)) {
640  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
641  builder, getLoc(),
642  builder.getIndexAttr(inferredResultType.getDimSize(i)));
643  } else {
644  reifiedReturnShapes[0][i] =
645  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
646  }
647  }
648 
649  if (getType().isDynamicDim(dim)) {
650  // Take the sum of the input sizes along the concatenated dim.
651  AffineExpr sum = builder.getAffineDimExpr(0);
652  SmallVector<OpFoldResult> sizes = {
653  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
654  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
655  sum = sum + builder.getAffineDimExpr(idx + 1);
656  sizes.push_back(
657  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
658  }
659  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
660  builder, getLoc(),
661  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
662  } else {
663  // If the result shape is static along the concatenated dim, use the static
664  // shape.
665  reifiedReturnShapes[0][dim] =
666  builder.getIndexAttr(getType().getDimSize(dim));
667  }
668  return success();
669 }
670 
671 void ConcatOp::getAsmResultNames(
672  function_ref<void(Value, StringRef)> setNameFn) {
673  setNameFn(getResult(), "concat");
674 }
675 
676 OpFoldResult ConcatOp::fold(FoldAdaptor) {
677  ValueRange inputs = getInputs();
678  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
679  return inputs[0];
680  return {};
681 }
682 
683 namespace {
684 /// Fold a concat op with a single input to a cast.
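/// Illustrative example (editor's sketch; exact assembly syntax may differ):
///   %0 = tensor.concat dim(0) %a : (tensor<4x?xf32>) -> tensor<4x5xf32>
/// becomes
///   %0 = tensor.cast %a : tensor<4x?xf32> to tensor<4x5xf32>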
685 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
686   using OpRewritePattern<ConcatOp>::OpRewritePattern;
687 
688  LogicalResult matchAndRewrite(ConcatOp concatOp,
689  PatternRewriter &rewriter) const override {
690  if (concatOp.getInputs().size() != 1)
691  return failure();
692  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
693  concatOp.getInputs()[0]);
694  return success();
695  }
696 };
697 } // namespace
698 
699 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
700  MLIRContext *context) {
701  results.add<SingleInputConcatOp>(context);
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // DimOp
706 //===----------------------------------------------------------------------===//
707 
708 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
709  setNameFn(getResult(), "dim");
710 }
711 
712 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
713  int64_t index) {
714  auto loc = result.location;
715  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
716  build(builder, result, source, indexValue);
717 }
718 
719 std::optional<int64_t> DimOp::getConstantIndex() {
720  return getConstantIntValue(getIndex());
721 }
722 
723 Speculation::Speculatability DimOp::getSpeculatability() {
724  auto constantIndex = getConstantIndex();
725  if (!constantIndex)
726    return Speculation::NotSpeculatable;
727 
728  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
729  if (!rankedSourceType)
730    return Speculation::NotSpeculatable;
731 
732  if (rankedSourceType.getRank() <= constantIndex)
733    return Speculation::NotSpeculatable;
734 
735  return Speculation::Speculatable;
736 }
737 
738 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
739  // All forms of folding require a known index.
740  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
741  if (!index)
742  return {};
743 
744  // Folding for unranked types (UnrankedTensorType) is not supported.
745  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
746  if (!tensorType)
747  return {};
748 
749  // Out of bound indices produce undefined behavior but are still valid IR.
750  // Don't choke on them.
751  int64_t indexVal = index.getInt();
752  if (indexVal < 0 || indexVal >= tensorType.getRank())
753  return {};
754 
755  // Fold if the shape extent along the given index is known.
756  if (!tensorType.isDynamicDim(index.getInt())) {
757  Builder builder(getContext());
758  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
759  }
760 
761  Operation *definingOp = getSource().getDefiningOp();
762 
763  // Fold dim to the operand of tensor.generate.
764  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
765  auto resultType =
766  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
767  // The case where the type encodes the size of the dimension is handled
768  // above.
769  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
770 
771  // Find the operand of the fromElements that corresponds to this index.
772  auto dynExtents = fromElements.getDynamicExtents().begin();
773  for (auto dim : resultType.getShape().take_front(index.getInt()))
774  if (ShapedType::isDynamic(dim))
775  dynExtents++;
776 
777  return Value{*dynExtents};
778  }
779 
780  // The size at the given index is now known to be a dynamic size.
781  unsigned unsignedIndex = index.getValue().getZExtValue();
782 
783  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
784  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
785  // `resolve-shaped-type-result-dims` pass.
786  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
787  sliceOp.isDynamicSize(unsignedIndex)) {
788  return {sliceOp.getDynamicSize(unsignedIndex)};
789  }
790  }
791 
792  // dim(cast) -> dim
793  if (succeeded(foldTensorCast(*this)))
794  return getResult();
795 
796  return {};
797 }
798 
799 namespace {
800 /// Fold dim of a cast into the dim of the source of the tensor cast.
801 struct DimOfCastOp : public OpRewritePattern<DimOp> {
802   using OpRewritePattern<DimOp>::OpRewritePattern;
803 
804  LogicalResult matchAndRewrite(DimOp dimOp,
805  PatternRewriter &rewriter) const override {
806  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
807  if (!castOp)
808  return failure();
809  Value newSource = castOp.getOperand();
810  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
811  return success();
812  }
813 };
814 
815 /// Fold dim of a destination passing style op into the dim of the corresponding
816 /// init.
817 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
818   using OpRewritePattern<DimOp>::OpRewritePattern;
819 
820  LogicalResult matchAndRewrite(DimOp dimOp,
821  PatternRewriter &rewriter) const override {
822  auto source = dimOp.getSource();
823  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
824  if (!destOp)
825  return failure();
826 
827  auto resultIndex = cast<OpResult>(source).getResultNumber();
828  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
829 
830  rewriter.modifyOpInPlace(
831  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
832  return success();
833  }
834 };
835 
836 /// Fold dim of a tensor reshape operation to an extract into the reshape's shape
837 /// operand.
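/// Illustrative example (editor's sketch, not from the original source):
///   %r = tensor.reshape %t(%shape)
///       : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
///   %d = tensor.dim %r, %c1 : tensor<?x?xf32>
/// becomes
///   %d = tensor.extract %shape[%c1] : tensor<2xindex>
/// with an additional arith.index_cast if the shape element type is not index.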
838 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
839   using OpRewritePattern<DimOp>::OpRewritePattern;
840 
841  LogicalResult matchAndRewrite(DimOp dim,
842  PatternRewriter &rewriter) const override {
843  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
844 
845  if (!reshape)
846  return failure();
847 
848  // Since tensors are immutable we don't need to worry about where to place
849  // the extract call
850  rewriter.setInsertionPointAfter(dim);
851  Location loc = dim.getLoc();
852  Value extract =
853  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
854  if (extract.getType() != dim.getType())
855  extract =
856  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
857  rewriter.replaceOp(dim, extract);
858  return success();
859  }
860 };
861 } // namespace
862 
863 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
864  MLIRContext *context) {
865  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // EmptyOp
870 //===----------------------------------------------------------------------===//
871 
872 void EmptyOp::build(OpBuilder &builder, OperationState &result,
873  ArrayRef<int64_t> staticShape, Type elementType,
874  Attribute encoding) {
875  assert(all_of(staticShape,
876  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
877  "expected only static sizes");
878  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
879 }
880 
881 void EmptyOp::build(OpBuilder &builder, OperationState &result,
882  ArrayRef<int64_t> staticShape, Type elementType,
883  ValueRange dynamicSizes, Attribute encoding) {
884  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
885  build(builder, result, tensorType, dynamicSizes);
886 }
887 
888 void EmptyOp::build(OpBuilder &builder, OperationState &result,
889  ArrayRef<OpFoldResult> sizes, Type elementType,
890  Attribute encoding) {
891  SmallVector<int64_t> staticShape;
892  SmallVector<Value> dynamicSizes;
893  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
894  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
895 }
896 
897 LogicalResult EmptyOp::verify() {
898   if (getType().getNumDynamicDims() !=
899  static_cast<int64_t>(getDynamicSizes().size()))
900  return emitOpError("incorrect number of dynamic sizes, has ")
901  << getDynamicSizes().size() << ", expected "
902  << getType().getNumDynamicDims();
903  return success();
904 }
905 
906 LogicalResult
907 EmptyOp::reifyResultShapes(OpBuilder &builder,
908                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
909  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
910  unsigned ctr = 0;
911  for (int64_t i = 0; i < getType().getRank(); ++i) {
912  if (getType().isDynamicDim(i)) {
913  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
914  } else {
915  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
916  }
917  }
918  return success();
919 }
920 
921 Value EmptyOp::getDynamicSize(unsigned idx) {
922  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
923  unsigned ctr = 0;
924  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
925  if (getType().isDynamicDim(i))
926  ++ctr;
927  return getDynamicSizes()[ctr];
928 }
929 
930 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
931   SmallVector<OpFoldResult> result;
932   unsigned ctr = 0;
933  OpBuilder b(getContext());
934  for (int64_t i = 0; i < getType().getRank(); ++i) {
935  if (getType().isDynamicDim(i)) {
936  result.push_back(getDynamicSizes()[ctr++]);
937  } else {
938  result.push_back(b.getIndexAttr(getType().getShape()[i]));
939  }
940  }
941  return result;
942 }
943 
944 namespace {
945 /// Change the type of the result of a `tensor.empty` by making the result
946 /// type statically sized along dimensions that in the original operation were
947 /// defined as dynamic, but the size was defined using a `constant` op. For
948 /// example
949 ///
950 /// %c5 = arith.constant 5: index
951 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
952 ///
953 /// to
954 ///
955 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
956 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
957   using OpRewritePattern<EmptyOp>::OpRewritePattern;
958 
959  LogicalResult matchAndRewrite(EmptyOp op,
960  PatternRewriter &rewriter) const override {
961  SmallVector<Value> foldedDynamicSizes;
962  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
963  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
964 
965  // Stop here if no dynamic size was promoted to static.
966  if (foldedTensorType == op.getType())
967  return failure();
968 
969  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
970  foldedDynamicSizes);
971  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
972  return success();
973  }
974 };
975 
976 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
977   using OpRewritePattern<DimOp>::OpRewritePattern;
978 
979  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
980  PatternRewriter &rewriter) const override {
981  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
982  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
983  if (!emptyTensorOp || !maybeConstantIndex)
984  return failure();
985  if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
986  return failure();
987  rewriter.replaceOp(dimOp,
988  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
989  return success();
990  }
991 };
992 
993 /// Canonicalize
994 ///
995 /// ```mlir
996 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
997 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
998 /// ```
999 ///
1000 /// into
1001 ///
1002 /// ```mlir
1003 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1004 /// ```
1005 ///
1006 /// This assumes the input program is correct in terms of its shape. So it is
1007 /// safe to assume that `%d0` is in fact 4.
1008 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1009   using OpRewritePattern<CastOp>::OpRewritePattern;
1010 
1011  LogicalResult matchAndRewrite(CastOp castOp,
1012  PatternRewriter &rewriter) const override {
1013  if (!canFoldIntoProducerOp(castOp))
1014  return failure();
1015  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1016  if (!producer)
1017  return failure();
1018 
1019  auto resultType =
1020  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1021  ArrayRef<int64_t> resultShape = resultType.getShape();
1022  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1023  SmallVector<OpFoldResult> newMixedSizes;
1024  newMixedSizes.reserve(currMixedSizes.size());
1025  assert(resultShape.size() == currMixedSizes.size() &&
1026  "mismatch in result shape and sizes of empty op");
1027  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1028  int64_t newDim = std::get<0>(it);
1029  OpFoldResult currDim = std::get<1>(it);
1030  // Case 1: The empty tensor dim is static. Check that the tensor cast
1031  // result dim matches.
1032  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1033  if (ShapedType::isDynamic(newDim) ||
1034  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1035  // Something is off, the cast result shape cannot be more dynamic
1036  // than the empty tensor result shape (enforced by
1037  // `canFoldIntoProducer`). Abort for now.
1038  return rewriter.notifyMatchFailure(
1039  producer, "mismatch in static value of shape of empty tensor "
1040  "result and cast result");
1041  }
1042  newMixedSizes.push_back(attr);
1043  continue;
1044  }
1045 
1046  // Case 2 : The tensor cast shape is static, but empty tensor result
1047  // shape is dynamic.
1048  if (!ShapedType::isDynamic(newDim)) {
1049  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1050  continue;
1051  }
1052 
1053  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1054  // shape is dynamic. Use the dynamic value from the empty tensor op.
1055  newMixedSizes.push_back(currDim);
1056  }
1057 
1058  // TODO: Do not drop tensor encoding.
1059  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1060  resultType.getElementType());
1061  return success();
1062  }
1063 };
1064 
1065 } // namespace
1066 
1067 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068  MLIRContext *context) {
1069  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1070  ReplaceEmptyTensorStaticShapeDims>(context);
1071 }
1072 
1073 /// Try to remove a tensor operation if it would only reshape a constant.
1074 /// Removes the op and replaces the constant with a new constant of the result
1075 /// shape. When an optional cst attribute is passed, it is reshaped only if the
1076 /// splat value matches the value in the attribute.
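/// Illustrative example (not from the original source): a splat constant
/// dense<1.0> : tensor<2x4xf32> reshaped to the static type tensor<8xf32>
/// folds to dense<1.0> : tensor<8xf32>; when `cst` is provided, the fold only
/// applies if the splat value equals `cst`.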
1077 static OpFoldResult
1078 reshapeConstantSource(DenseElementsAttr source, TensorType result,
1079                       std::optional<Attribute> cst = std::nullopt) {
1080  if (source && source.isSplat() && result.hasStaticShape() &&
1081  (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
1082  return source.resizeSplat(result);
1083 
1084  return {};
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // ExtractOp
1089 //===----------------------------------------------------------------------===//
1090 
1091 namespace {
1092 
1093 /// Canonicalizes the pattern of the form
1094 ///
1095 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1096 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1097 ///
1098 /// to
1099 ///
1100 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1101 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1102   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1103 
1104  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1105  PatternRewriter &rewriter) const final {
1106  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1107  if (!tensorCast)
1108  return failure();
1109  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1110  return failure();
1111  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1112  extract, tensorCast.getSource(), extract.getIndices());
1113  return success();
1114  }
1115 };
1116 
1117 } // namespace
1118 
1119 void ExtractOp::getAsmResultNames(
1120  function_ref<void(Value, StringRef)> setNameFn) {
1121  setNameFn(getResult(), "extracted");
1122 }
1123 
1124 LogicalResult ExtractOp::verify() {
1125   // Verify the # indices match if we have a ranked type.
1126  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1127  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1128  return emitOpError("incorrect number of indices for extract_element");
1129  return success();
1130 }
1131 
1132 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1133  // If this is a splat elements attribute, simply return the value. All of
1134  // the elements of a splat attribute are the same.
1135  if (Attribute tensor = adaptor.getTensor())
1136  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1137  return splatTensor.getSplatValue<Attribute>();
1138 
1139  // Collect the constant indices into the tensor.
1140  SmallVector<uint64_t, 8> indices;
1141  for (Attribute indice : adaptor.getIndices()) {
1142  if (!indice || !llvm::isa<IntegerAttr>(indice))
1143  return {};
1144  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1145  }
1146 
1147  // Fold extract(from_elements(...)).
1148  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1149  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1150  auto rank = tensorType.getRank();
1151  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1152  "rank mismatch");
1153  int flatIndex = 0;
1154  int stride = 1;
1155  for (int i = rank - 1; i >= 0; --i) {
1156  flatIndex += indices[i] * stride;
1157  stride *= tensorType.getDimSize(i);
1158  }
1159  // Prevent out of bounds accesses. This can happen in invalid code that
1160  // will never execute.
1161  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1162  flatIndex < 0)
1163  return {};
1164  return fromElementsOp.getElements()[flatIndex];
1165  }
1166 
1167  // If this is an elements attribute, query the value at the given indices.
1168  if (Attribute tensor = adaptor.getTensor()) {
1169  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1170  if (elementsAttr && elementsAttr.isValidIndex(indices))
1171  return elementsAttr.getValues<Attribute>()[indices];
1172  }
1173 
1174  return {};
1175 }
1176 
1177 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1178  MLIRContext *context) {
1179  results.add<ExtractFromTensorCast>(context);
1180 }
1181 
1182 //===----------------------------------------------------------------------===//
1183 // FromElementsOp
1184 //===----------------------------------------------------------------------===//
1185 
1186 void FromElementsOp::getAsmResultNames(
1187  function_ref<void(Value, StringRef)> setNameFn) {
1188  setNameFn(getResult(), "from_elements");
1189 }
1190 
1191 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1192  ValueRange elements) {
1193  assert(!elements.empty() && "expected at least one element");
1194  Type resultType = RankedTensorType::get(
1195  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1196  build(builder, result, resultType, elements);
1197 }
1198 
1199 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1200  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1201  return DenseElementsAttr::get(getType(), adaptor.getElements());
1202  return {};
1203 }
1204 
1205 namespace {
1206 
1207 // Pushes the index_casts that occur before extractions to after the extract.
1208 // This minimizes type conversion in some cases and enables the extract
1209 // canonicalizer. This changes:
1210 //
1211 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1212 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1213 //
1214 // to the following:
1215 //
1216 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1217 // %cast = arith.index_cast %extract : i32 to index
1218 //
1219 // which enables further canonicalization of the extract, e.g. down to just
1220 // the underlying element when the source tensor is built from elements.
1220 //
1221 // Consider expanding this to a template and handle all tensor cast
1222 // operations.
1223 struct ExtractElementFromIndexCast
1224  : public OpRewritePattern<tensor::ExtractOp> {
1225   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1226 
1227  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1228  PatternRewriter &rewriter) const final {
1229  Location loc = extract.getLoc();
1230  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1231  if (!indexCast)
1232  return failure();
1233 
1234  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1235 
1236  auto newExtract = rewriter.create<tensor::ExtractOp>(
1237  loc, elementTy, indexCast.getIn(), extract.getIndices());
1238 
1239  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1240  newExtract);
1241 
1242  return success();
1243  }
1244 };
1245 
1246 } // namespace
1247 
1248 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1249  MLIRContext *context) {
1250  results.add<ExtractElementFromIndexCast>(context);
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // GatherOp
1255 //===----------------------------------------------------------------------===//
1256 
1257 void GatherOp::getAsmResultNames(
1258  function_ref<void(Value, StringRef)> setNameFn) {
1259  setNameFn(getResult(), "gather");
1260 }
1261 
1262 /// Return the inferred result type for a gatherOp where:
1263 /// - sourceType is the type of the source tensor gathered from
1264 /// - indicesType is the type of the indices used to gather
1265 /// - gatherDims are the dims along which the gather occurs.
1266 /// Return a full rank or ranked-reduced variant of the type depending on
1267 /// Return a full rank or rank-reduced variant of the type depending on
1268 ///
1269 /// The leading dimensions of the index tensor give the result tensor its
1270 /// leading dimensions.
1271 /// The trailing dimensions of the result tensor are obtained from the source
1272 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1273 /// rankReduced is false), or skipping them (otherwise).
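/// Illustrative example (not from the original source): for a source of type
/// tensor<4x5x6xf32>, indices of type tensor<2x3x1xindex>, and gather_dims =
/// [1], the inferred type is tensor<2x3x4x1x6xf32>, or tensor<2x3x4x6xf32> in
/// the rank-reduced variant.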
1274 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1275  RankedTensorType indicesType,
1276  ArrayRef<int64_t> gatherDims,
1277  bool rankReduced) {
1278  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1279  resultShape.reserve(resultShape.size() + sourceType.getRank());
1280  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1281  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1282  if (!rankReduced)
1283  resultShape.push_back(1);
1284  continue;
1285  }
1286  resultShape.push_back(sourceType.getDimSize(idx));
1287  }
1288  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1289 }
1290 
1291 static LogicalResult
1292 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
1293                           StringRef gatherOrScatter, StringRef sourceOrDest) {
1294  if (dims.empty())
1295  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1296 
1297  int64_t numGatherDims = dims.size();
1298  if (numGatherDims > rank)
1299  return op->emitOpError(gatherOrScatter)
1300  << "_dims overflow " << sourceOrDest << " rank";
1301  for (int64_t val : dims) {
1302  if (val < 0)
1303  return op->emitOpError(gatherOrScatter)
1304  << "_dims value must be non-negative";
1305  if (val >= rank)
1306  return op->emitOpError(gatherOrScatter)
1307  << "_dims value must be smaller than " << sourceOrDest << " rank";
1308  }
1309  for (int64_t i = 1; i < numGatherDims; ++i) {
1310  if (dims[i - 1] >= dims[i])
1311  return op->emitOpError(gatherOrScatter)
1312  << "_dims values must be strictly increasing";
1313  }
1314  return success();
1315 }
1316 
1317 LogicalResult GatherOp::verify() {
1318   int64_t sourceRank = getSourceType().getRank();
1319  ArrayRef<int64_t> gatherDims = getGatherDims();
1320  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
1321  "gather", "source")))
1322  return failure();
1323 
1324  RankedTensorType expectedResultType = GatherOp::inferResultType(
1325  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1326  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1327  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1328  if (getResultType() != expectedResultType &&
1329  getResultType() != expectedRankReducedResultType) {
1330  return emitOpError("result type "
1331  "mismatch: "
1332  "expected ")
1333  << expectedResultType << " or its rank-reduced variant "
1334  << expectedRankReducedResultType << " (got: " << getResultType()
1335  << ")";
1336  }
1337 
1338  return success();
1339 }
1340 
1341 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1342  if (OpFoldResult reshapedSource = reshapeConstantSource(
1343  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1344  getResult().getType()))
1345  return reshapedSource;
1346  return {};
1347 }
1348 
1349 //===----------------------------------------------------------------------===//
1350 // InsertOp
1351 //===----------------------------------------------------------------------===//
1352 
1353 void InsertOp::getAsmResultNames(
1354  function_ref<void(Value, StringRef)> setNameFn) {
1355  setNameFn(getResult(), "inserted");
1356 }
1357 
1358 LogicalResult InsertOp::verify() {
1359   // Verify the # indices match if we have a ranked type.
1360  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1361  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1362  return emitOpError("incorrect number of indices");
1363  return success();
1364 }
1365 
1366 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1367  Attribute scalar = adaptor.getScalar();
1368  Attribute dest = adaptor.getDest();
1369  if (scalar && dest)
1370  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1371  if (scalar == splatDest.getSplatValue<Attribute>())
1372  return dest;
1373  return {};
1374 }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // GenerateOp
1378 //===----------------------------------------------------------------------===//
1379 
1380 void GenerateOp::getAsmResultNames(
1381  function_ref<void(Value, StringRef)> setNameFn) {
1382  setNameFn(getResult(), "generated");
1383 }
1384 
1385 LogicalResult GenerateOp::reifyResultShapes(
1386     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1387  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1388  int idx = 0;
1389  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1390  if (getType().isDynamicDim(dim)) {
1391  reifiedReturnShapes[0][dim] = getOperand(idx++);
1392  } else {
1393  reifiedReturnShapes[0][dim] =
1394  builder.getIndexAttr(getType().getDimSize(dim));
1395  }
1396  }
1397  return success();
1398 }
1399 
1400 LogicalResult GenerateOp::verify() {
1401   // Ensure that the tensor type has as many dynamic dimensions as are
1402  // specified by the operands.
1403  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1404  if (getNumOperands() != resultType.getNumDynamicDims())
1405  return emitError("must have as many index operands as dynamic extents "
1406  "in the result type");
1407  return success();
1408 }
1409 
1410 LogicalResult GenerateOp::verifyRegions() {
1411  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1412  // Ensure that region arguments span the index space.
1413  if (!llvm::all_of(getBody().getArgumentTypes(),
1414  [](Type ty) { return ty.isIndex(); }))
1415  return emitError("all body arguments must be index");
1416  if (getBody().getNumArguments() != resultTy.getRank())
1417  return emitError("must have one body argument per input dimension");
1418 
1419  // Ensure that the region yields an element of the right type.
1420  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1421 
1422  if (yieldOp.getValue().getType() != resultTy.getElementType())
1423  return emitOpError(
1424  "body must be terminated with a `yield` operation of the tensor "
1425  "element type");
1426 
1427  return success();
1428 }
1429 
1430 void GenerateOp::build(
1431  OpBuilder &b, OperationState &result, Type resultTy,
1432  ValueRange dynamicExtents,
1433  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1434  build(b, result, resultTy, dynamicExtents);
1435 
1436  // Build and populate body.
1437  OpBuilder::InsertionGuard guard(b);
1438  Region *bodyRegion = result.regions.front().get();
1439  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1440  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1441  SmallVector<Location, 2> argumentLocs(rank, result.location);
1442  Block *bodyBlock =
1443  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1444  bodyBuilder(b, result.location, bodyBlock->getArguments());
1445 }
1446 
1447 namespace {
1448 
1449 /// Canonicalizes tensor.generate operations with a constant
1450 /// operand into the equivalent operation with the operand expressed in the
1451 /// result type, instead. We also insert a type cast to make sure that the
1452 /// resulting IR is still well-typed.
1453 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1454   using OpRewritePattern<GenerateOp>::OpRewritePattern;
1455 
1456  LogicalResult matchAndRewrite(GenerateOp generateOp,
1457  PatternRewriter &rewriter) const final {
1458  SmallVector<Value> foldedDynamicSizes;
1459  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1460  generateOp.getType(), generateOp.getDynamicExtents(),
1461  foldedDynamicSizes);
1462 
1463  // Stop here if no dynamic size was promoted to static.
1464  if (foldedTensorType == generateOp.getType())
1465  return failure();
1466 
1467  auto loc = generateOp.getLoc();
1468  auto newOp =
1469  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1470  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1471  newOp.getBody().begin());
1472  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1473  generateOp.getType(), newOp);
1474  return success();
1475  }
1476 };
1477 
1478 /// Canonicalizes the pattern of the form
1479 ///
1480 /// %tensor = tensor.generate %x {
1481 /// ^bb0(%arg0: index):
1482 /// <computation>
1483 /// yield %1 : index
1484 /// } : tensor<?xindex>
1485 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1486 ///
1487 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1488 /// tensor.generate operation has no side-effects.
1489 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1490   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1491 
1492  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1493  PatternRewriter &rewriter) const final {
1494  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1495  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1496  return failure();
1497 
1498  IRMapping mapping;
1499  Block *body = &tensorFromElements.getBody().front();
1500  mapping.map(body->getArguments(), extract.getIndices());
1501  for (auto &op : body->without_terminator())
1502  rewriter.clone(op, mapping);
1503 
1504  auto yield = cast<YieldOp>(body->getTerminator());
1505 
1506  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1507  return success();
1508  }
1509 };
1510 
1511 } // namespace
1512 
1513 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1514  MLIRContext *context) {
1515  // TODO: Move extract pattern to tensor::ExtractOp.
1516  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1517 }
1518 
1519 //===----------------------------------------------------------------------===//
1520 // RankOp
1521 //===----------------------------------------------------------------------===//
1522 
1523 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1524  setNameFn(getResult(), "rank");
1525 }
1526 
1527 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1528  // Constant fold rank when the rank of the operand is known.
1529  auto type = getOperand().getType();
1530  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1531  if (shapedType && shapedType.hasRank())
1532  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1533  return IntegerAttr();
1534 }
1535 
1536 //===----------------------------------------------------------------------===//
1537 // ReshapeOp
1538 //===----------------------------------------------------------------------===//
1539 
1540 void ReshapeOp::getAsmResultNames(
1541  function_ref<void(Value, StringRef)> setNameFn) {
1542  setNameFn(getResult(), "reshape");
1543 }
1544 
1545 static int64_t getNumElements(ShapedType type) {
1546  int64_t numElements = 1;
1547  for (auto dim : type.getShape())
1548  numElements *= dim;
1549  return numElements;
1550 }
1551 
1552 LogicalResult ReshapeOp::verify() {
1553   TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1554  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1555 
1556  if (operandType.getElementType() != resultType.getElementType())
1557  return emitOpError("element types of source and destination tensor "
1558  "types should be the same");
1559 
1560  int64_t shapeSize =
1561  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1562  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1563  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1564 
1565  if (resultRankedType) {
1566  if (operandRankedType && resultRankedType.hasStaticShape() &&
1567  operandRankedType.hasStaticShape()) {
1568  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1569  return emitOpError("source and destination tensor should have the "
1570  "same number of elements");
1571  }
1572  if (ShapedType::isDynamic(shapeSize))
1573  return emitOpError("cannot use shape operand with dynamic length to "
1574  "reshape to statically-ranked tensor type");
1575  if (shapeSize != resultRankedType.getRank())
1576  return emitOpError(
1577  "length of shape operand differs from the result's tensor rank");
1578  }
1579  return success();
1580 }
1581 
1582 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1583  if (OpFoldResult reshapedSource = reshapeConstantSource(
1584  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1585  getResult().getType()))
1586  return reshapedSource;
1587 
1588  auto source = getSource();
1589  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1590  auto resultTy = dyn_cast<RankedTensorType>(getType());
1591 
1592  if (!sourceTy || !resultTy || sourceTy != resultTy)
1593  return {};
1594 
1595  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1596  auto elements = fromElements.getElements();
1597  bool dynamicNoop =
1598  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1599  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1600  auto element = elements[id];
1601 
1602  if (auto cst = getConstantIntValue(element)) {
1603  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1604  continue;
1605  }
1606 
1607  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1608  dynamicNoop &= dimOp.getSource() == source;
1609 
1610  APSInt dim;
1611  auto cst = getConstantIntValue(dimOp.getIndex());
1612  dynamicNoop &=
1613  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1614  continue;
1615  }
1616 
1617  dynamicNoop = false;
1618  break;
1619  }
1620 
1621  if (dynamicNoop)
1622  return source;
1623  }
1624 
1625  return {};
1626 }
1627 
1628 //===----------------------------------------------------------------------===//
1629 // Reassociative reshape ops
1630 //===----------------------------------------------------------------------===//
1631 
1632 void CollapseShapeOp::getAsmResultNames(
1633  function_ref<void(Value, StringRef)> setNameFn) {
1634  setNameFn(getResult(), "collapsed");
1635 }
1636 
1637 void ExpandShapeOp::getAsmResultNames(
1638  function_ref<void(Value, StringRef)> setNameFn) {
1639  setNameFn(getResult(), "expanded");
1640 }
1641 
1642 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1643  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1644  "invalid resultDim");
1645  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1646  if (llvm::is_contained(it.value(), resultDim))
1647  return it.index();
1648  llvm_unreachable("could not find reassociation group");
1649 }
1650 
1651 FailureOr<SmallVector<OpFoldResult>>
1652 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1653  RankedTensorType expandedType,
1654  ArrayRef<ReassociationIndices> reassociation,
1655  ArrayRef<OpFoldResult> inputShape) {
1656  std::optional<SmallVector<OpFoldResult>> outputShape =
1657  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1658  inputShape);
1659  if (!outputShape)
1660  return failure();
1661  return *outputShape;
1662 }
1663 
1664 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1665  Type resultType, Value src,
1666  ArrayRef<ReassociationIndices> reassociation,
1667  ArrayRef<OpFoldResult> outputShape) {
1668  auto [staticOutputShape, dynamicOutputShape] =
1669  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1670  build(builder, result, cast<RankedTensorType>(resultType), src,
1671  getReassociationIndicesAttribute(builder, reassociation),
1672  dynamicOutputShape, staticOutputShape);
1673 }
1674 
1675 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1676  Type resultType, Value src,
1677  ArrayRef<ReassociationIndices> reassociation) {
1678  SmallVector<OpFoldResult> inputShape =
1679  getMixedSizes(builder, result.location, src);
1680  auto tensorResultTy = cast<RankedTensorType>(resultType);
1681  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1682  builder, result.location, tensorResultTy, reassociation, inputShape);
1683  SmallVector<OpFoldResult> outputShapeOrEmpty;
1684  if (succeeded(outputShape)) {
1685  outputShapeOrEmpty = *outputShape;
1686  }
1687  build(builder, result, tensorResultTy, src, reassociation,
1688  outputShapeOrEmpty);
1689 }
1690 
1691 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1692  return getSymbolLessAffineMaps(getReassociationExprs());
1693 }
1694 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1695  return convertReassociationIndicesToExprs(getContext(),
1696  getReassociationIndices());
1697 }
1698 
1699 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1700  return getSymbolLessAffineMaps(getReassociationExprs());
1701 }
1702 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1703  return convertReassociationIndicesToExprs(getContext(),
1704  getReassociationIndices());
1705 }
1706 
1707 RankedTensorType CollapseShapeOp::inferCollapsedType(
1708  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1709  return inferCollapsedType(
1710  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1711  type.getContext(), reassociation)));
1712 }
1713 
1714 /// Compute the RankedTensorType obtained by applying `reassociation` to
1715 /// `type`.
1716 RankedTensorType
1717 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1718  ArrayRef<AffineMap> reassociation) {
1719  auto shape = type.getShape();
1720  SmallVector<int64_t, 4> newShape;
1721  newShape.reserve(reassociation.size());
1722 
1723  // Use the fact that reassociation is valid to simplify the logic: only use
1724  // each map's rank.
1725  assert(isReassociationValid(reassociation) && "invalid reassociation");
1726  unsigned currentDim = 0;
1727  for (AffineMap m : reassociation) {
1728  unsigned dim = m.getNumResults();
1729  auto band = shape.slice(currentDim, dim);
1730  int64_t size = 1;
1731  if (llvm::is_contained(band, ShapedType::kDynamic))
1732  size = ShapedType::kDynamic;
1733  else
1734  for (unsigned d = 0; d < dim; ++d)
1735  size *= shape[currentDim + d];
1736  newShape.push_back(size);
1737  currentDim += dim;
1738  }
1739 
1740  return RankedTensorType::get(newShape, type.getElementType());
1741 }
1742 
1743 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1744  ArrayRef<ReassociationIndices> reassociation,
1745  ArrayRef<NamedAttribute> attrs) {
1746  auto resultType = inferCollapsedType(
1747  llvm::cast<RankedTensorType>(src.getType()),
1748  getSymbolLessAffineMaps(
1749  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1750  result.addAttribute(getReassociationAttrStrName(),
1751  getReassociationIndicesAttribute(b, reassociation));
1752  build(b, result, resultType, src, attrs);
1753 }
1754 
1755 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1756  TensorReshapeOp, ExpandShapeOp>::value>
1757 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1758  RankedTensorType expandedType,
1759  RankedTensorType collapsedType) {
1760  if (failed(
1761  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1762  return failure();
1763 
1764  auto maps = op.getReassociationMaps();
1765  RankedTensorType expectedType =
1766  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1767  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1768  return op.emitOpError("expected collapsed type to be ")
1769  << expectedType << ", but got " << collapsedType;
1770  return success();
1771 }
1772 
1773 LogicalResult ExpandShapeOp::verify() {
1774  auto srcType = getSrcType();
1775  auto resultType = getResultType();
1776 
1777  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1778  return emitOpError("expected number of static shape dims to be equal to "
1779  "the output rank (")
1780  << resultType.getRank() << ") but found "
1781  << getStaticOutputShape().size() << " inputs instead";
1782 
1783  if ((int64_t)getOutputShape().size() !=
1784  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1785  return emitOpError("mismatch in dynamic dims in output_shape and "
1786  "static_output_shape: static_output_shape has ")
1787  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1788  << " dynamic dims while output_shape has " << getOutputShape().size()
1789  << " values";
1790 
1791  return verifyTensorReshapeOp(*this, resultType, srcType);
1792 }
1793 
1794 LogicalResult CollapseShapeOp::verify() {
1795  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1796 }
1797 
1798 namespace {
1799 /// Reshape of a splat constant can be replaced with a constant of the result
1800 /// type.
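/// For illustration, a possible input/output pair (sketch only; types and
/// values below are arbitrary, not taken from an actual test):
/// ```mlir
///   %cst = arith.constant dense<1.0> : tensor<2x4xf32>
///   %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
/// ```
/// becomes a single constant of the result type:
/// ```mlir
///   %0 = arith.constant dense<1.0> : tensor<8xf32>
/// ```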
1801 template <typename TensorReshapeOp>
1802 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1803  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1804  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1805  PatternRewriter &rewriter) const override {
1806  DenseElementsAttr attr;
1807  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1808  return failure();
1809  if (!attr || !attr.isSplat())
1810  return failure();
1811  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1812  reshapeOp.getResultType(), attr.getRawData());
1813  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1814  return success();
1815  }
1816 };
1817 
1818 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
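// For example (illustrative IR only; the value and shapes are arbitrary):
//
//   %0 = tensor.splat %x : tensor<2x4xf32>
//   %1 = tensor.expand_shape %0 [[0], [1, 2]] output_shape [2, 2, 2]
//       : tensor<2x4xf32> into tensor<2x2x2xf32>
//
// folds to:
//
//   %1 = tensor.splat %x : tensor<2x2x2xf32>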
1819 template <typename TensorReshapeOp>
1820 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1821 public:
1822  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1823 
1824  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1825  PatternRewriter &rewriter) const override {
1826  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1827  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1828  return failure();
1829 
1830  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1831  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1832  return success();
1833  }
1834 };
1835 
1836 /// Reshape of a FromElements can be replaced with a FromElements of the
1837 /// result type
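/// For illustration (sketch only; element values are arbitrary):
/// ```mlir
///   %0 = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>
///   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 2]
///       : tensor<4xf32> into tensor<2x2xf32>
/// ```
/// becomes:
/// ```mlir
///   %1 = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
/// ```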
1838 template <typename TensorReshapeOp>
1839 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1840  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1841  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1842  PatternRewriter &rewriter) const override {
1843  auto fromElements =
1844  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1845  if (!fromElements)
1846  return failure();
1847 
1848  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1849 
1850  if (!shapedTy.hasStaticShape())
1851  return failure();
1852 
1853  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1854  fromElements.getElements());
1855  return success();
1856  }
1857 };
1858 
1859 // Fold CastOp into CollapseShapeOp when adding static information.
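// For example (illustrative IR only; shapes are arbitrary):
//
//   %0 = tensor.cast %arg : tensor<2x4xf32> to tensor<?x4xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
//
// becomes:
//
//   %2 = tensor.collapse_shape %arg [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
//   %1 = tensor.cast %2 : tensor<8xf32> to tensor<?xf32>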
1860 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1861  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1862 
1863  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1864  PatternRewriter &rewriter) const override {
1865  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1866  if (!tensor::canFoldIntoConsumerOp(castOp))
1867  return failure();
1868 
1869  RankedTensorType srcType =
1870  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1871  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1872  srcType, collapseShapeOp.getReassociationMaps());
1873 
1874  if (newResultType == collapseShapeOp.getResultType()) {
1875  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1876  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1877  });
1878  } else {
1879  auto newOp = rewriter.create<CollapseShapeOp>(
1880  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1881  collapseShapeOp.getReassociation());
1882  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1883  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1884  }
1885  return success();
1886  }
1887 };
1888 
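/// Fold a tensor.dim of a dynamic result dimension of tensor.expand_shape into
/// an affine.apply on the corresponding source dimension. For illustration
/// (sketch only; shapes are arbitrary):
/// ```mlir
///   %e = tensor.expand_shape %t [[0, 1]] output_shape [%n, 4]
///       : tensor<?xf32> into tensor<?x4xf32>
///   %d = tensor.dim %e, %c0 : tensor<?x4xf32>
/// ```
/// is rewritten to `affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%src_dim]`,
/// where `%src_dim` is the size of the source dimension (`tensor.dim %t, %c0`).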
1889 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1890  using OpRewritePattern<DimOp>::OpRewritePattern;
1891 
1892  LogicalResult matchAndRewrite(DimOp dimOp,
1893  PatternRewriter &rewriter) const override {
1894  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1895  if (!expandShapeOp)
1896  return failure();
1897 
1898  // Only constant dimension values are supported.
1899  std::optional<int64_t> dim = dimOp.getConstantIndex();
1900  if (!dim.has_value())
1901  return failure();
1902 
1903  // Skip static dims. These are folded to constant ops.
1904  RankedTensorType resultType = expandShapeOp.getResultType();
1905  if (!resultType.isDynamicDim(*dim))
1906  return failure();
1907 
1908  // Find reassociation group that contains this result dimension.
1909  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1910 
1911  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1912  // ExpandShapeOp would be ambiguous.)
1913  int64_t product = 1;
1914  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1915  for (int64_t d : grp) {
1916  if (d != dim) {
1917  assert(!resultType.isDynamicDim(d) && "expected static dim");
1918  product *= resultType.getDimSize(d);
1919  }
1920  }
1921 
1922  // result dim size = src dim size / (product(other dims in reassoc group))
1923  Value srcDimSz =
1924  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1925  AffineExpr expr;
1926  bindSymbols(dimOp.getContext(), expr);
1927  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1928  dimOp, expr.floorDiv(product), srcDimSz);
1929  return success();
1930  }
1931 };
1932 
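/// Fold a tensor.dim of a dynamic result dimension of tensor.collapse_shape
/// into an affine.apply over the source dimensions of its reassociation group.
/// For illustration (sketch only; shapes are arbitrary):
/// ```mlir
///   %c = tensor.collapse_shape %t [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
///   %d = tensor.dim %c, %c0 : tensor<?xf32>
/// ```
/// becomes `affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%d0, %d1]`, with
/// `%d0`/`%d1` the sizes of the two source dimensions in the group.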
1933 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1934  using OpRewritePattern<DimOp>::OpRewritePattern;
1935 
1936  LogicalResult matchAndRewrite(DimOp dimOp,
1937  PatternRewriter &rewriter) const override {
1938  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1939  if (!collapseShapeOp)
1940  return failure();
1941 
1942  // Only constant dimension values are supported.
1943  std::optional<int64_t> dim = dimOp.getConstantIndex();
1944  if (!dim.has_value())
1945  return failure();
1946 
1947  // Skip static dims. These are folded to constant ops.
1948  RankedTensorType resultType = collapseShapeOp.getResultType();
1949  if (!resultType.isDynamicDim(*dim))
1950  return failure();
1951 
1952  // Get reassociation group of the result dimension.
1953  ReassociationIndices group =
1954  collapseShapeOp.getReassociationIndices()[*dim];
1955 
1956  // result dim size = product(dims in reassoc group)
1957  SmallVector<Value> srcDimSizes;
1958  SmallVector<AffineExpr> syms;
1959  AffineExpr product;
1960  for (const auto &it : llvm::enumerate(group)) {
1961  srcDimSizes.push_back(rewriter.create<DimOp>(
1962  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1963  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1964  product = product ? product * syms.back() : syms.back();
1965  }
1966  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
1967  srcDimSizes);
1968  return success();
1969  }
1970 };
1971 } // namespace
1972 
1973 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1974  MLIRContext *context) {
1975  results.add<
1976  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
1977  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
1978  FoldReshapeWithConstant<ExpandShapeOp>,
1979  FoldReshapeWithSplat<ExpandShapeOp>,
1980  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1981  FoldDimOfCollapseShape>(context);
1982 }
1983 
1984 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1985  MLIRContext *context) {
1986  results.add<
1987  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
1988  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
1989  tensor::DimOp, RankedTensorType>,
1990  FoldReshapeWithConstant<CollapseShapeOp>,
1991  FoldReshapeWithSplat<CollapseShapeOp>,
1992  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
1993  context);
1994 }
1995 
1996 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
1997  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
1998  adaptor.getOperands());
1999 }
2000 
2001 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2002  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2003  adaptor.getOperands());
2004 }
2005 
2006 //===----------------------------------------------------------------------===//
2007 // ExtractSliceOp
2008 //===----------------------------------------------------------------------===//
2009 
2010 void ExtractSliceOp::getAsmResultNames(
2011  function_ref<void(Value, StringRef)> setNameFn) {
2012  setNameFn(getResult(), "extracted_slice");
2013 }
2014 
2015 /// An extract_slice result type can be inferred, when it is not
2016 /// rank-reduced, from the source type and the static representation of
2017 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2018 RankedTensorType ExtractSliceOp::inferResultType(
2019  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2020  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2021  // An extract_slice op may specify only a leading subset of offsets/sizes/
2022  // strides, in which case we complete with offset=0, sizes from the source
2023  // tensor type, and strides=1.
2024  assert(static_cast<int64_t>(staticSizes.size()) ==
2025  sourceTensorType.getRank() &&
2026  "unexpected staticSizes not equal to rank of source");
2027  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2028  sourceTensorType.getEncoding());
2029 }
2030 
2031 RankedTensorType ExtractSliceOp::inferResultType(
2032  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2033  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2034  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2035  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2036  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2037  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2038  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2039  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2040  staticSizes, staticStrides);
2041 }
2042 
2043 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2044 /// number of sizes), drop as many size-1 dimensions as needed to produce an
2045 /// inferred type with the desired rank.
2046 ///
2047 /// Note that there may be multiple ways to compute this rank-reduced type:
2048 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2049 ///
2050 /// To disambiguate, this function always drops the size-1 dimensions that occur earliest.
2051 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2052  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2053  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2054  ArrayRef<int64_t> strides) {
2055  // Type inferred in the absence of rank-reducing behavior.
2056  auto inferredType = llvm::cast<RankedTensorType>(
2057  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2058  int rankDiff = inferredType.getRank() - desiredResultRank;
2059  if (rankDiff > 0) {
2060  auto shape = inferredType.getShape();
2061  llvm::SmallBitVector dimsToProject =
2062  getPositionsOfShapeOne(rankDiff, shape);
2063  SmallVector<int64_t> projectedShape;
2064  // Best effort rank-reducing: drop 1s in order.
2065  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2066  if (!dimsToProject.test(pos))
2067  projectedShape.push_back(shape[pos]);
2068  inferredType =
2069  RankedTensorType::get(projectedShape, inferredType.getElementType());
2070  }
2071  return inferredType;
2072 }
2073 
2074 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2075  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2076  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2077  ArrayRef<OpFoldResult> strides) {
2078  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2079  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2080  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2081  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2082  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2083  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2084  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2085  staticStrides);
2086 }
2087 
2088 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2089 /// result type. If the type passed is nullptr, it is inferred.
2090 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2091  RankedTensorType resultType, Value source,
2092  ArrayRef<OpFoldResult> offsets,
2093  ArrayRef<OpFoldResult> sizes,
2094  ArrayRef<OpFoldResult> strides,
2095  ArrayRef<NamedAttribute> attrs) {
2096  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2097  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2098  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2099  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2100  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2101  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2102  // Structuring implementation this way avoids duplication between builders.
2103  if (!resultType) {
2104  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2105  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2106  }
2107  result.addAttributes(attrs);
2108  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2109  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2110  b.getDenseI64ArrayAttr(staticSizes),
2111  b.getDenseI64ArrayAttr(staticStrides));
2112 }
2113 
2114 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2115 /// result type.
2116 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2117  ArrayRef<OpFoldResult> offsets,
2118  ArrayRef<OpFoldResult> sizes,
2119  ArrayRef<OpFoldResult> strides,
2120  ArrayRef<NamedAttribute> attrs) {
2121  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2122 }
2123 
2124 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2125 /// a Range vector.
2126 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2127  ArrayRef<Range> ranges,
2128  ArrayRef<NamedAttribute> attrs) {
2129  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2130  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2131 }
2132 
2133 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2134 /// the type passed is nullptr, it is inferred.
2135 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2136  RankedTensorType resultType, Value source,
2137  ValueRange offsets, ValueRange sizes,
2138  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2139  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2140  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2141  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2142  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2143  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2144  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2145  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2146 }
2147 
2148 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2149 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2150  ValueRange offsets, ValueRange sizes,
2151  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2152  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2153 }
2154 
2155 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2156  Operation *op,
2157  RankedTensorType expectedType) {
2158  switch (result) {
2159  case SliceVerificationResult::Success:
2160  return success();
2161  case SliceVerificationResult::RankTooLarge:
2162  return op->emitError("expected rank to be smaller or equal to ")
2163  << "the other rank. ";
2164  case SliceVerificationResult::SizeMismatch:
2165  return op->emitError("expected type to be ")
2166  << expectedType << " or a rank-reduced version. (size mismatch) ";
2167  case SliceVerificationResult::ElemTypeMismatch:
2168  return op->emitError("expected element type to be ")
2169  << expectedType.getElementType();
2170  default:
2171  llvm_unreachable("unexpected extract_slice op verification result");
2172  }
2173 }
2174 
2175 /// Verifier for ExtractSliceOp.
2176 LogicalResult ExtractSliceOp::verify() {
2177  // Verify result type against inferred type.
2178  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2179  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2180  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2181  return produceSliceErrorMsg(result, *this, expectedType);
2182 }
2183 
2184 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2185  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2186 }
2187 
2188 FailureOr<Value>
2189 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2190  ArrayRef<int64_t> desiredShape) {
2191  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2192  assert(sourceTensorType && "not a ranked tensor type");
2193  auto sourceShape = sourceTensorType.getShape();
2194  if (sourceShape.equals(desiredShape))
2195  return value;
2196  auto maybeRankReductionMask =
2197  mlir::computeRankReductionMask(sourceShape, desiredShape);
2198  if (!maybeRankReductionMask)
2199  return failure();
2200  return createCanonicalRankReducingExtractSliceOp(
2201  b, loc, value,
2202  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2203 }
2204 
2205 LogicalResult ExtractSliceOp::reifyResultShapes(
2206  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2207  reifiedReturnShapes.resize(1);
2208  reifiedReturnShapes[0].reserve(getType().getRank());
2209  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2210  llvm::SmallBitVector droppedDims = getDroppedDims();
2211  for (const auto &size : enumerate(mixedSizes)) {
2212  if (droppedDims.test(size.index()))
2213  continue;
2214  reifiedReturnShapes[0].push_back(size.value());
2215  }
2216  return success();
2217 }
2218 
2219 namespace {
2220 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2221 /// This essentially pushes the tensor.cast past its consuming extract_slice
2222 /// when `canFoldIntoConsumerOp` is true.
2223 ///
2224 /// Example:
2225 /// ```
2226 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2227 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2228 /// tensor<3x4xf32>
2229 /// ```
2230 /// is rewritten into:
2231 /// ```
2232 ///   %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2233 ///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2234 /// ```
2235 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2236 public:
2237  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2238 
2239  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2240  PatternRewriter &rewriter) const override {
2241  // Any constant operand, just return to let the constant folder kick in.
2242  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2243  return matchPattern(operand, matchConstantIndex());
2244  }))
2245  return failure();
2246 
2247  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2248  if (!castOp)
2249  return failure();
2250 
2251  if (!canFoldIntoConsumerOp(castOp))
2252  return failure();
2253 
2254  // Create folded extract.
2255  Location loc = sliceOp.getLoc();
2256  Value newResult = rewriter.create<ExtractSliceOp>(
2257  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2258  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2259  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2260  if (newResult.getType() != sliceOp.getType())
2261  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2262  rewriter.replaceOp(sliceOp, newResult);
2263  return success();
2264  }
2265 };
2266 
2267 /// Slice elements from `values` into `outValues`. For each dimension, `counts`
2268 /// holds the number of elements in `values` spanned by one step along that
2269 /// dimension. The output values can be used to construct a DenseElementsAttr.
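/// As a worked example (assumed values, for illustration only): slicing a 3x4
/// source with offsets [1, 0], sizes [2, 2], and strides [1, 2] uses
/// counts = [4, 1]; the recursion visits flattened indices 4, 6, 8, 10, i.e.
/// elements (1,0), (1,2), (2,0), (2,2) of the source.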
2270 template <typename IterTy, typename ElemTy>
2271 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2272  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2273  ArrayRef<int64_t> strides,
2274  llvm::SmallVectorImpl<ElemTy> *outValues) {
2275  assert(offsets.size() == sizes.size());
2276  assert(offsets.size() == strides.size());
2277  if (offsets.empty())
2278  return;
2279 
2280  int64_t offset = offsets.front();
2281  int64_t size = sizes.front();
2282  int64_t stride = strides.front();
2283  if (offsets.size() == 1) {
2284  for (int64_t i = 0; i < size; ++i, offset += stride)
2285  outValues->push_back(*(values + offset));
2286 
2287  return;
2288  }
2289 
2290  for (int64_t i = 0; i < size; ++i, offset += stride) {
2291  auto begin = values + offset * counts.front();
2292  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2293  offsets.drop_front(), sizes.drop_front(),
2294  strides.drop_front(), outValues);
2295  }
2296 }
2297 
2298 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2299 /// folded operation might introduce more constant data; users can control
2300 /// whether the fold applies through the control function.
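/// For illustration (sketch only; the constant data is arbitrary):
/// ```mlir
///   %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
///   %0 = tensor.extract_slice %cst[0, 1] [2, 2] [1, 1]
///       : tensor<2x3xi32> to tensor<2x2xi32>
/// ```
/// can be folded, if the control function allows it, into:
/// ```mlir
///   %0 = arith.constant dense<[[1, 2], [4, 5]]> : tensor<2x2xi32>
/// ```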
2301 class ConstantOpExtractSliceFolder final
2302  : public OpRewritePattern<ExtractSliceOp> {
2303 public:
2304  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2305 
2306  ConstantOpExtractSliceFolder(MLIRContext *context,
2307  ControlConstantExtractSliceFusionFn controlFn)
2308  : OpRewritePattern<ExtractSliceOp>(context),
2309  controlFn(std::move(controlFn)) {}
2310 
2311  LogicalResult matchAndRewrite(ExtractSliceOp op,
2312  PatternRewriter &rewriter) const override {
2313  DenseElementsAttr attr;
2314  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2315  return failure();
2316 
2317  // A constant splat is handled by fold().
2318  if (attr.isSplat())
2319  return failure();
2320 
2321  // Dynamic result shape is not supported.
2322  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2323  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2324  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2325  return failure();
2326 
2327  // Customized control over the folding.
2328  if (!controlFn(op))
2329  return failure();
2330 
2331  int64_t count = sourceType.getNumElements();
2332  if (count == 0)
2333  return failure();
2334 
2335  // Check if there are any dynamic parts, which are not supported.
2336  auto offsets = op.getStaticOffsets();
2337  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2338  return failure();
2339  auto sizes = op.getStaticSizes();
2340  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2341  return failure();
2342  auto strides = op.getStaticStrides();
2343  if (llvm::is_contained(strides, ShapedType::kDynamic))
2344  return failure();
2345 
2346  // Compute the stride for each dimension.
2347  SmallVector<int64_t> counts;
2348  ArrayRef<int64_t> shape = sourceType.getShape();
2349  counts.reserve(shape.size());
2350  for (int64_t v : shape) {
2351  count = count / v;
2352  counts.push_back(count);
2353  }
2354 
2355  // New attribute constructed by the sliced values.
2356  DenseElementsAttr newAttr;
2357 
2358  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2359  SmallVector<APInt> outValues;
2360  outValues.reserve(sourceType.getNumElements());
2361  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2362  elems.begin(), counts, offsets, sizes, strides, &outValues);
2363  newAttr = DenseElementsAttr::get(resultType, outValues);
2364  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2365  SmallVector<APFloat> outValues;
2366  outValues.reserve(sourceType.getNumElements());
2367  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2368  elems.begin(), counts, offsets, sizes, strides, &outValues);
2369  newAttr = DenseElementsAttr::get(resultType, outValues);
2370  }
2371 
2372  if (newAttr) {
2373  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2374  return success();
2375  }
2376 
2377  return failure();
2378  }
2379 
2380 private:
2381  /// Additionally controls whether the fold happens. Users can impose their
2382  /// own heuristics through this function.
2383  ControlConstantExtractSliceFusionFn controlFn;
2384 };
2385 
2386 } // namespace
2387 
2388 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2389  RewritePatternSet &patterns,
2390  const ControlConstantExtractSliceFusionFn &controlFn) {
2391  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2392 }
2393 
2394 /// Return the canonical type of the result of an extract_slice op.
2395 struct SliceReturnTypeCanonicalizer {
2396  RankedTensorType operator()(ExtractSliceOp op,
2397  ArrayRef<OpFoldResult> mixedOffsets,
2398  ArrayRef<OpFoldResult> mixedSizes,
2399  ArrayRef<OpFoldResult> mixedStrides) {
2400  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2401  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2402  mixedStrides);
2403  }
2404 };
2405 
2406 /// A canonicalizer wrapper to replace ExtractSliceOps.
2407 struct SliceCanonicalizer {
2408  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2409  ExtractSliceOp newOp) {
2410  Value replacement = newOp.getResult();
2411  if (replacement.getType() != op.getType())
2412  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2413  replacement);
2414  rewriter.replaceOp(op, replacement);
2415  }
2416 };
2417 
2418 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2419  MLIRContext *context) {
2420  results.add<
2421  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2422  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2423  ExtractSliceOpCastFolder>(context);
2424 }
2425 
2426 /// Return success if `op` is an identity slice of `shapedType`: all offsets
/// are 0, all strides are 1, and the sizes match the type's shape.
2427 static LogicalResult
2428 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2429  ShapedType shapedType) {
2430  OpBuilder b(op.getContext());
2431  for (OpFoldResult ofr : op.getMixedOffsets())
2432  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2433  return failure();
2434  // Rank-reducing noops only need to inspect the leading dimensions:
2435  // llvm::zip is appropriate.
2436  auto shape = shapedType.getShape();
2437  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2438  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2439  return failure();
2440  for (OpFoldResult ofr : op.getMixedStrides())
2441  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2442  return failure();
2443  return success();
2444 }
2445 
2446 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2447 /// slice, we can return the InsertSliceOp's source directly.
2448 // TODO: This only checks the immediate producer; extend to go up the
2449 // insert/extract chain if the slices are disjoint.
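/// For illustration (sketch only; shapes are arbitrary):
/// ```mlir
///   %0 = tensor.insert_slice %src into %dst[0, 0] [16, 16] [1, 1]
///       : tensor<16x16xf32> into tensor<64x64xf32>
///   %1 = tensor.extract_slice %0[0, 0] [16, 16] [1, 1]
///       : tensor<64x64xf32> to tensor<16x16xf32>
/// ```
/// folds %1 to %src.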
2450 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2451  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2452 
2453  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2454  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2455  insertOp.isSameAs(extractOp, isSame))
2456  return insertOp.getSource();
2457 
2458  return {};
2459 }
2460 
2461 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2462  if (OpFoldResult reshapedSource = reshapeConstantSource(
2463  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2464  getResult().getType()))
2465  return reshapedSource;
2466  if (getSourceType() == getType() &&
2467  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2468  return this->getSource();
2469  if (Value slice = foldExtractAfterInsertSlice(*this))
2470  return slice;
2471 
2472  return OpFoldResult();
2473 }
2474 
2475 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2476  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2477  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2478  unsigned rank = rankedTensorType.getRank();
2479  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2480  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2481  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2482  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2483  offsets, sizes, strides);
2484 }
2485 
2486 //===----------------------------------------------------------------------===//
2487 // InsertSliceOp
2488 //===----------------------------------------------------------------------===//
2489 
2490 void InsertSliceOp::getAsmResultNames(
2491  function_ref<void(Value, StringRef)> setNameFn) {
2492  setNameFn(getResult(), "inserted_slice");
2493 }
2494 
2495 // Build a InsertSliceOp with mixed static and dynamic entries.
2496 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2497  Value dest, ArrayRef<OpFoldResult> offsets,
2498  ArrayRef<OpFoldResult> sizes,
2499  ArrayRef<OpFoldResult> strides,
2500  ArrayRef<NamedAttribute> attrs) {
2501  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2502  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2503  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2504  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2505  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2506  result.addAttributes(attrs);
2507  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2508  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2509  b.getDenseI64ArrayAttr(staticSizes),
2510  b.getDenseI64ArrayAttr(staticStrides));
2511 }
2512 
2513 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2514 /// Range vector.
2515 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2516  Value dest, ArrayRef<Range> ranges,
2517  ArrayRef<NamedAttribute> attrs) {
2518  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2519  build(b, result, source, dest, offsets, sizes, strides, attrs);
2520 }
2521 
2522 // Build a InsertSliceOp with dynamic entries.
2523 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2524  Value dest, ValueRange offsets, ValueRange sizes,
2525  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2526  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2527  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2528  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2529  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2530  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2531  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2532  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2533 }
2534 
2535 /// Rank-reducing type verification for both InsertSliceOp and
2536 /// ParallelInsertSliceOp.
2537 static SliceVerificationResult verifyInsertSliceOp(
2538  RankedTensorType srcType, RankedTensorType dstType,
2539  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2540  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2541  // insert_slice is the inverse of extract_slice, use the same type
2542  // inference.
2543  RankedTensorType expected = ExtractSliceOp::inferResultType(
2544  dstType, staticOffsets, staticSizes, staticStrides);
2545  if (expectedType)
2546  *expectedType = expected;
2547  return isRankReducedType(expected, srcType);
2548 }
2549 
2550 /// Verifier for InsertSliceOp.
2551 LogicalResult InsertSliceOp::verify() {
2552  RankedTensorType expectedType;
2553  SliceVerificationResult result =
2554  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2555  getStaticSizes(), getStaticStrides(), &expectedType);
2556  return produceSliceErrorMsg(result, *this, expectedType);
2557 }
2558 
2559 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2560 /// can mutate the second InsertSliceOp's destination to the first one's.
2561 ///
2562 /// Example:
2563 ///
2564 /// ```mlir
2565 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2566 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2567 /// ```
2568 ///
2569 /// folds into:
2570 ///
2571 /// ```mlir
2572 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2573 /// ```
2574 ///
2575 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2576 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2577  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2578 
2579  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2580  if (!prevInsertOp ||
2581  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2582  !prevInsertOp.isSameAs(insertOp, isSame))
2583  return failure();
2584 
2585  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2586  return success();
2587 }
2588 
2589 /// Folds round-trip extract/insert slice op pairs.
2590 /// Example:
2591 /// ```mlir
2592 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2593 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2594 /// ```
2595 /// can be folded into %val.
2596 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2597  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2598 
2599  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2600  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2601  !extractOp.isSameAs(insertOp, isSame))
2602  return nullptr;
2603 
2604  return extractOp.getSource();
2605 }
2606 
2607 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2608  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2609  getSourceType() == getType() &&
2610  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2611  return this->getSource();
2612  if (succeeded(foldInsertAfterInsertSlice(*this)))
2613  return getResult();
2614  if (auto result = foldInsertAfterExtractSlice(*this))
2615  return result;
2616  if (llvm::any_of(getMixedSizes(),
2617  [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2618  return getDest();
2619  return OpFoldResult();
2620 }
2621 
2622 LogicalResult InsertSliceOp::reifyResultShapes(
2623  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2624  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2625  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2626  return success();
2627 }
2628 
2629 namespace {
2630 /// Pattern to rewrite an insert_slice op with constant arguments.
2631 ///
2632 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2633 template <typename InsertOpTy>
2634 class InsertSliceOpConstantArgumentFolder final
2635  : public OpRewritePattern<InsertOpTy> {
2636 public:
2637  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2638 
2639  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2640  PatternRewriter &rewriter) const override {
2641  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2642  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2643  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2644 
2645  // No constant operands were folded, just return;
2646  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2647  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2648  failed(foldDynamicStrideList(mixedStrides)))
2649  return failure();
2650 
2651  // Create the new op in canonical form.
2652  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2653  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2654  mixedOffsets, mixedSizes, mixedStrides);
2655  Value toInsert = insertSliceOp.getSource();
2656  if (sourceType != insertSliceOp.getSourceType()) {
2657  OpBuilder::InsertionGuard g(rewriter);
2658  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2659  // is that the insertion point is just before the ParallelCombiningOp in
2660  // the parallel case.
2661  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2662  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2663  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2664  sourceType, toInsert);
2665  }
2666  rewriter.replaceOpWithNewOp<InsertOpTy>(
2667  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2668  mixedSizes, mixedStrides);
2669  return success();
2670  }
2671 };
2672 
2673 /// Fold tensor_casts with insert_slice operations. If the source or
2674 /// destination tensor is a tensor_cast that removes static type information,
2675 /// the cast is folded into the insert_slice operation. E.g.:
2676 ///
2677 /// ```mlir
2678 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2679 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2680 /// ```
2681 ///
2682 /// folds into:
2683 ///
2684 /// ```mlir
2685 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2686 /// ```
2687 ///
2688 /// Note: When folding a cast on the destination tensor, the result of the
2689 /// insert_slice operation is cast to ensure that the type of the result did
2690 /// not change.
2691 ///
2692 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2693 template <typename InsertOpTy>
2694 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2695  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2696 
2697  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2698  PatternRewriter &rewriter) const override {
2699  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2700  return matchPattern(operand, matchConstantIndex());
2701  }))
2702  return failure();
2703 
2704  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2705  auto castOp = v.getDefiningOp<tensor::CastOp>();
2706  if (!castOp || !canFoldIntoConsumerOp(castOp))
2707  return std::nullopt;
2708  return castOp.getSource();
2709  };
2710  std::optional<Value> sourceCastSource =
2711  getSourceOfCastOp(insertSliceOp.getSource());
2712  std::optional<Value> destCastSource =
2713  getSourceOfCastOp(insertSliceOp.getDest());
2714  if (!sourceCastSource && !destCastSource)
2715  return failure();
2716 
2717  auto src =
2718  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2719  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2720  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2721  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2722  if (!srcType || !dstType)
2723  return failure();
2724 
2725  // The tensor.cast source could have additional static information not seen
2726  // in the insert slice op static sizes, so we ignore dynamic dims when
2727  // computing the rank reduction mask.
2728  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2729  auto rankReductionMask = computeRankReductionMask(
2730  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2731  if (!rankReductionMask.has_value())
2732  return failure();
2733  // Replace dimensions in the insert slice op with corresponding static dims
2734  // from the cast source type. If the insert slice sizes have static dims
2735  // that are not static in the tensor.cast source (i.e., when the cast op
2736  // casts a dynamic dim to static), the dim should not be replaced, and the
2737  // pattern will fail later in `verifyInsertSliceOp`.
2738  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2739  int64_t rankReducedIdx = 0;
2740  for (auto [idx, size] : enumerate(staticSizes)) {
2741  if (!rankReductionMask.value().contains(idx) &&
2742  !srcType.isDynamicDim(rankReducedIdx)) {
2743  mixedSizes[idx] = getAsIndexOpFoldResult(
2744  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2745  size = srcType.getDimSize(rankReducedIdx++);
2746  }
2747  }
2748  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2749  staticSizes, insertSliceOp.getStaticStrides()) !=
2750  SliceVerificationResult::Success)
2751  return failure();
2752 
2753  Operation *replacement = rewriter.create<InsertOpTy>(
2754  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2755  mixedSizes, insertSliceOp.getMixedStrides());
2756 
2757  // In the parallel case there is no result and so nothing to cast.
2758  bool isParallelInsert =
2759  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2760  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2761  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2762  insertSliceOp.getDestType(),
2763  replacement->getResult(0));
2764  }
2765  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2766  return success();
2767  }
2768 };
2769 
2770 /// If additional static type information can be deduced from an insert_slice's
2771 /// size operands, insert an explicit cast of the op's source operand. This
2772 /// enables other canonicalization patterns that are matching for tensor_cast
2773 /// ops such as `ForOpTensorCastFolder` in SCF.
2774 ///
2775 /// Example:
2776 ///
2777 /// ```mlir
2778 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2779 /// : tensor<?x?xf32> into ...
2780 /// ```
2781 ///
2782 /// folds into:
2783 ///
2784 /// ```mlir
2785 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2786 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2787 /// : tensor<64x64xf32> into ...
2788 /// ```
2789 ///
2790 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2791 template <typename InsertOpTy>
2792 struct InsertSliceOpSourceCastInserter final
2793  : public OpRewritePattern<InsertOpTy> {
2794  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2795 
2796  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2797  PatternRewriter &rewriter) const override {
2798  RankedTensorType srcType = insertSliceOp.getSourceType();
2799  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2800  return failure();
2801  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
2802  srcType.getShape().end());
2803  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2804  if (std::optional<int64_t> constInt =
2805  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2806  // Bail on invalid IR.
2807  if (*constInt < 0)
2808  return failure();
2809  newSrcShape[i] = *constInt;
2810  }
2811  }
2812  if (!hasValidSizesOffsets(newSrcShape))
2813  return failure();
2814 
2815  RankedTensorType newSrcType = RankedTensorType::get(
2816  newSrcShape, srcType.getElementType(), srcType.getEncoding());
2817  if (srcType == newSrcType ||
2818  !preservesStaticInformation(srcType, newSrcType) ||
2819  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2820  return failure();
2821 
2822  // newSrcType is:
2823  // 1) Different from srcType.
2824  // 2) "More static" than srcType.
2825  // 3) Cast-compatible with srcType.
2826  // Insert the cast.
2827  OpBuilder::InsertionGuard g(rewriter);
2828  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2829  // that the insertion point is just before the ParallelCombiningOp in the
2830  // parallel case.
2831  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2832  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2833  Value cast = rewriter.create<tensor::CastOp>(
2834  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2835  rewriter.replaceOpWithNewOp<InsertOpTy>(
2836  insertSliceOp, cast, insertSliceOp.getDest(),
2837  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2838  insertSliceOp.getMixedStrides());
2839  return success();
2840  }
2841 };
2842 } // namespace
2843 
2844 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2845  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2846 }
2847 
2848 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2849  MLIRContext *context) {
2850  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2851  InsertSliceOpCastFolder<InsertSliceOp>,
2852  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2853 }
2854 
2855 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2856  Location loc,
2857  Value tensor,
2858  Value dest) {
2859  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
2860  unsigned rank = rankedTensorType.getRank();
2861  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2862  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2863  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2864  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2865  sizes, strides);
2866 }
2867 
2868 //===----------------------------------------------------------------------===//
2869 // PadOp
2870 //===----------------------------------------------------------------------===//
2871 
2872 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2873  setNameFn(getResult(), "padded");
2874 }
2875 
2876 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2877 // supports optional types.
2878 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2879  Type typeToInfer, Type typeToInferFrom) {}
2880 
2881 ParseResult
2882 parseInferType(OpAsmParser &parser,
2883  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2884  Type &typeToInfer, Type typeToInferFrom) {
2885  if (optOperand)
2886  typeToInfer = typeToInferFrom;
2887  return success();
2888 }
2889 
2890 LogicalResult PadOp::verify() {
2891  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
2892  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
2893  auto expectedType =
2894  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2895  if (!expectedType) {
2896  return emitError("failed to infer expectedType from sourceType ")
2897  << sourceType << ", specified resultType is " << resultType;
2898  }
2899  if (resultType.getRank() != expectedType.getRank()) {
2900  return emitError("specified type ")
2901  << resultType << " does not match the inferred type "
2902  << expectedType;
2903  }
2904  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2905  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2906  continue;
2907  if (expectedType.isDynamicDim(i))
2908  continue;
2909  return emitError("specified type ")
2910  << resultType << " does not match the inferred type "
2911  << expectedType;
2912  }
2913 
2914  return success();
2915 }
2916 
2917 LogicalResult PadOp::verifyRegions() {
2918  auto &region = getRegion();
2919  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2920  Block &block = region.front();
2921  if (block.getNumArguments() != rank)
2922  return emitError("expected the block to have ") << rank << " arguments";
2923 
2924  // Note: the number and type of yield values are checked in the YieldOp.
2925  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2926  if (!en.value().isIndex())
2927  return emitOpError("expected block argument ")
2928  << (en.index() + 1) << " to be an index";
2929  }
2930 
2931  // Ensure that the region yields an element of the right type.
2932  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2933  if (yieldOp.getValue().getType() !=
2934  llvm::cast<ShapedType>(getType()).getElementType())
2935  return emitOpError("expected yield type to match shape element type");
2936 
2937  return success();
2938 }
2939 
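/// Infer the result type from the source type and the static low/high padding
/// amounts; dynamic source dims and dynamic padding amounts yield dynamic
/// result dims. For example (assumed values, for illustration only): padding a
/// tensor<2x?xf32> with staticLow = [1, 2] and staticHigh = [3, kDynamic]
/// infers tensor<6x?xf32>.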
2940 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2941  ArrayRef<int64_t> staticLow,
2942  ArrayRef<int64_t> staticHigh,
2943  ArrayRef<int64_t> resultShape) {
2944  unsigned rank = sourceType.getRank();
2945  if (staticLow.size() != rank)
2946  return RankedTensorType();
2947  if (staticHigh.size() != rank)
2948  return RankedTensorType();
2949  if (!resultShape.empty() && resultShape.size() != rank)
2950  return RankedTensorType();
2951 
2952  SmallVector<int64_t, 4> inferredShape;
2953  for (auto i : llvm::seq<unsigned>(0, rank)) {
2954  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2955  staticHigh[i] == ShapedType::kDynamic) {
2956  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2957  : resultShape[i]);
2958  } else {
2959  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2960  assert((resultShape.empty() || size == resultShape[i] ||
2961  resultShape[i] == ShapedType::kDynamic) &&
2962  "mismatch between inferred shape and result shape");
2963  inferredShape.push_back(size);
2964  }
2965  }
2966 
2967  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2968 }
2969 
2970 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2971  Value source, ArrayRef<int64_t> staticLow,
2972  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
2973  bool nofold, ArrayRef<NamedAttribute> attrs) {
2974  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2975  if (!resultType)
2976  resultType = inferResultType(sourceType, staticLow, staticHigh);
2977  result.addAttributes(attrs);
2978  build(b, result, resultType, source, low, high,
2979  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2980  nofold ? b.getUnitAttr() : UnitAttr());
2981 }
2982 
2983 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2984  Value source, ValueRange low, ValueRange high, bool nofold,
2985  ArrayRef<NamedAttribute> attrs) {
2986  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2987  unsigned rank = sourceType.getRank();
2988  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
2989  build(b, result, resultType, source, staticVector, staticVector, low, high,
2990  nofold, attrs);
2991 }
2992 
2993 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2994  Value source, ArrayRef<OpFoldResult> low,
2995  ArrayRef<OpFoldResult> high, bool nofold,
2996  ArrayRef<NamedAttribute> attrs) {
2997  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2998  SmallVector<Value, 4> dynamicLow, dynamicHigh;
2999  SmallVector<int64_t, 4> staticLow, staticHigh;
3000  // staticLow and staticHigh carry the complete padding configuration.
3001  // Each padding entry grows staticLow and staticHigh by one value. If the
3002  // entry is dynamic (i.e. not a constant), dynamicLow and dynamicHigh grow
3003  // by one value as well.
3004  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3005  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3006  if (!resultType) {
3007  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3008  }
3009  assert(llvm::isa<RankedTensorType>(resultType));
3010  result.addAttributes(attrs);
3011  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3012  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3013  nofold ? b.getUnitAttr() : UnitAttr());
3014 }
3015 
3016 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3017  Value source, ArrayRef<OpFoldResult> low,
3018  ArrayRef<OpFoldResult> high, Value constantPadValue,
3019  bool nofold, ArrayRef<NamedAttribute> attrs) {
3020  build(b, result, resultType, source, low, high, nofold, attrs);
3021 
3022  // Add a region and a block to yield the pad value.
3023  Region *region = result.regions[0].get();
3024  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3025  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3026  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3027 
3028  // `builder.createBlock` changes the insertion point within the block. Create
3029  // a guard to reset the insertion point of the builder after it is destroyed.
3030  OpBuilder::InsertionGuard guard(b);
3031  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3032  b.create<tensor::YieldOp>(result.location, constantPadValue);
3033 }
3034 
3035 llvm::SmallBitVector PadOp::getPaddedDims() {
3036  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3037  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3038  for (const auto &en : enumerate(paddingWidths))
3039  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3040  paddedDims.set(en.index());
3041  };
3042  extractPaddedDims(getMixedLowPad());
3043  extractPaddedDims(getMixedHighPad());
3044  return paddedDims;
3045 }
3046 
3047 namespace {
3048 // Folds tensor.pad when the padding amounts are all static zeros and the
3049 // `nofold` attribute is not set.
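// Hedged illustration (assumed IR, for exposition only):
//
// ```mlir
// %1 = tensor.pad %0 low[0, 0] high[0, 0] { ... }
//     : tensor<4x4xf32> to tensor<4x4xf32>
// // becomes
// %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x4xf32>
// ```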
3050 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3051  using OpRewritePattern<PadOp>::OpRewritePattern;
3052 
3053  LogicalResult matchAndRewrite(PadOp padTensorOp,
3054  PatternRewriter &rewriter) const override {
3055  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3056  return failure();
3057  if (padTensorOp.getNofold())
3058  return failure();
3059  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3060  padTensorOp, padTensorOp.getResult().getType(),
3061  padTensorOp.getSource());
3062  return success();
3063  }
3064 };
3065 
3066 // Fold CastOp into PadOp when adding static information.
3067 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3068  using OpRewritePattern<PadOp>::OpRewritePattern;
3069 
3070  LogicalResult matchAndRewrite(PadOp padTensorOp,
3071  PatternRewriter &rewriter) const override {
3072  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3073  if (!tensor::canFoldIntoConsumerOp(castOp))
3074  return failure();
3075 
3076  auto newResultType = PadOp::inferResultType(
3077  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3078  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3079  padTensorOp.getResultType().getShape());
3080 
3081  if (newResultType == padTensorOp.getResultType()) {
3082  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3083  padTensorOp.getSourceMutable().assign(castOp.getSource());
3084  });
3085  } else {
3086  auto newOp = rewriter.create<PadOp>(
3087  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3088  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3089  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3090  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3091  IRMapping mapper;
3092  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3093 
3094  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3095  padTensorOp, padTensorOp.getResultType(), newOp);
3096  }
3097  return success();
3098  }
3099 };
3100 
3101 // Fold CastOp using the result of PadOp back into the latter if it adds
3102 // static information.
3103 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3104  using OpRewritePattern<PadOp>::OpRewritePattern;
3105 
3106  LogicalResult matchAndRewrite(PadOp padTensorOp,
3107  PatternRewriter &rewriter) const override {
3108  if (!padTensorOp.getResult().hasOneUse())
3109  return failure();
3110  auto tensorCastOp =
3111  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3112  if (!tensorCastOp)
3113  return failure();
3114  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3115  tensorCastOp.getDest().getType()))
3116  return failure();
3117 
3118  auto replacementOp = rewriter.create<PadOp>(
3119  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3120  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3121  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3122  padTensorOp.getHigh(), padTensorOp.getNofold(),
3123  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3124  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3125 
3126  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3127  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3128  return success();
3129  }
3130 };
3131 
3132 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3133 /// different dimensions. The pattern applies if the following preconditions
3134 /// hold:
3135 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3136 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3137 /// 3) the tensor::PadOps perform only high-padding,
3138 /// 4) the tensor::PadOps have the same constant padding value,
3139 /// 5) the tensor::PadOps do not have common padding dimensions,
3140 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3141 /// zero-offset for every dimension.
3142 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3143 ///    the padded source dimensions.
3145 ///
3146 /// Example:
3147 ///
3148 /// ```mlir
3149 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3150 /// : tensor<64x64xf32> to tensor<?x64xf32>
3151 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3152 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3153 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3154 /// : tensor<8x64xf32> to tensor<8x?xf32>
3155 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3156 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3157 /// ```
3158 ///
3159 /// folds into:
3160 ///
3161 /// ```mlir
3162 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3163 /// : tensor<64x64xf32> to tensor<?x?xf32>
3164 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3165 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3166 /// ```
3167 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3168  using OpRewritePattern<PadOp>::OpRewritePattern;
3169 
3170  LogicalResult matchAndRewrite(PadOp padOp,
3171  PatternRewriter &rewriter) const override {
3172  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3173  if (!innerSliceOp)
3174  return failure();
3175  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3176  if (!outerPadOp || outerPadOp.getNofold())
3177  return failure();
3178  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3179  if (!outerSliceOp)
3180  return failure();
3181 
3182  // 1) Fail if the chain is rank-reducing.
3183  int64_t rank = padOp.getSourceType().getRank();
3184  if (outerSliceOp.getSourceType().getRank() != rank) {
3185  return rewriter.notifyMatchFailure(padOp,
3186  "cannot fold rank-reducing chain");
3187  }
3188 
3189  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3190  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3191  return rewriter.notifyMatchFailure(
3192  padOp, "cannot fold non-unit stride ExtractSliceOps");
3193  }
3194 
3195  // 3) Fail if the tensor::PadOps have non-zero low padding.
3196  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3197  return rewriter.notifyMatchFailure(padOp,
3198  "cannot fold PadOps with low padding");
3199  }
3200 
3201  // 4) Fail if the tensor::PadOps padding values do not match.
3202  Attribute innerAttr, outerAttr;
3203  Value innerValue = padOp.getConstantPaddingValue();
3204  Value outerValue = outerPadOp.getConstantPaddingValue();
3205  if (!innerValue || !outerValue ||
3206  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3207  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3208  innerAttr != outerAttr) {
3209  return rewriter.notifyMatchFailure(
3210  padOp, "cannot fold PadOps with different padding values");
3211  }
3212 
3213  // 5) Fail if a dimension is padded by both tensor::PadOps.
3214  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3215  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3216  if (innerDims.anyCommon(outerDims)) {
3217  return rewriter.notifyMatchFailure(
3218  padOp, "cannot fold PadOps with common padding dimensions");
3219  }
3220 
3221  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3222  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3223  // for every dimension, and use the offset of the other pair. Fail if no
3224  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3225  // exists.
3226  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3227  for (auto en : enumerate(newOffsets)) {
3228  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3229  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3230  if (!innerDims.test(en.index()) &&
3231  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3232  en.value() = outerOffset;
3233  continue;
3234  }
3235  if (!outerDims.test(en.index()) &&
3236  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3237  en.value() = innerOffset;
3238  continue;
3239  }
3240  return rewriter.notifyMatchFailure(
3241  padOp, "cannot find zero-offset and zero-padding pair");
3242  }
3243 
3244  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3245  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3246  // outer tensor::PadOp and fail if the size of the inner
3247  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3248  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3249  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3250  for (auto en : enumerate(newSizes)) {
3251  if (!outerDims.test(en.index()))
3252  continue;
3253  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3254  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3255  assert(!ShapedType::isDynamic(sourceSize) &&
3256  "expected padded dimension to have a static size");
3257  if (getConstantIntValue(sliceSize) != sourceSize) {
3258  return rewriter.notifyMatchFailure(
3259  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3260  "match the size of the outer padding");
3261  }
3262  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3263  }
3264 
3265  // Combine the high paddings of the two tensor::PadOps.
3266  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3267  for (auto en : enumerate(newHighPad)) {
3268  if (innerDims.test(en.index()))
3269  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3270  if (outerDims.test(en.index()))
3271  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3272  }
3273 
3274  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3275  // the two paddings in one step.
3276  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3277  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3278  innerSliceOp.getMixedStrides());
3279  auto newPadOp = rewriter.create<PadOp>(
3280  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3281  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3282  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3283  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3284  newPadOp.getRegion().begin());
3285  rewriter.replaceOp(padOp, newPadOp.getResult());
3286  return success();
3287  }
3288 };
3289 
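// Moves pad amounts produced by constants into the static low/high
// attributes, refining the result type and casting back to the original type.
// Hedged illustration (assumed IR, for exposition only):
//
// ```mlir
// %c2 = arith.constant 2 : index
// %0 = tensor.pad %t low[%c2, 0] high[0, 0] { ... }
//     : tensor<4x4xf32> to tensor<?x4xf32>
// // becomes
// %1 = tensor.pad %t low[2, 0] high[0, 0] { ... }
//     : tensor<4x4xf32> to tensor<6x4xf32>
// %0 = tensor.cast %1 : tensor<6x4xf32> to tensor<?x4xf32>
// ```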
3290 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3291  using OpRewritePattern<PadOp>::OpRewritePattern;
3292 
3293  LogicalResult matchAndRewrite(PadOp padTensorOp,
3294  PatternRewriter &rewriter) const override {
3295  Value input = padTensorOp.getSource();
3296  if (!llvm::isa<RankedTensorType>(input.getType()))
3297  return failure();
3298  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3299  auto inputRank = inputDims.size();
3300 
3301  auto oldResultType =
3302  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3303  if (!oldResultType)
3304  return failure();
3305 
3306  auto outputDims = oldResultType.getShape();
3307 
3308  // Extract the static info from the high and low operands.
3309  SmallVector<int64_t> constOperandsLow;
3310  SmallVector<Value> newLows;
3311  for (auto operand : padTensorOp.getLow()) {
3312  APSInt intOp;
3313  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3314  constOperandsLow.push_back(ShapedType::kDynamic);
3315  newLows.push_back(operand);
3316  continue;
3317  }
3318  constOperandsLow.push_back(intOp.getExtValue());
3319  }
3320  SmallVector<int64_t> constOperandsHigh;
3321  SmallVector<Value> newHighs;
3322  for (auto operand : padTensorOp.getHigh()) {
3323  APSInt intOp;
3324  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3325  constOperandsHigh.push_back(ShapedType::kDynamic);
3326  newHighs.push_back(operand);
3327  continue;
3328  }
3329  constOperandsHigh.push_back(intOp.getExtValue());
3330  }
3331 
3332  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3333  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3334 
3335  // Verify the op is well-formed.
3336  if (inputDims.size() != outputDims.size() ||
3337  inputDims.size() != constLow.size() ||
3338  inputDims.size() != constHigh.size())
3339  return failure();
3340 
3341  auto lowCount = 0;
3342  auto highCount = 0;
3343  for (size_t i = 0; i < inputRank; i++) {
3344  if (constLow[i] == ShapedType::kDynamic)
3345  constLow[i] = constOperandsLow[lowCount++];
3346  if (constHigh[i] == ShapedType::kDynamic)
3347  constHigh[i] = constOperandsHigh[highCount++];
3348  }
3349 
3350  auto staticLow = ArrayRef<int64_t>(constLow);
3351  auto staticHigh = ArrayRef<int64_t>(constHigh);
3352 
3353  // Calculate the output sizes with the static information.
3354  SmallVector<int64_t> newOutDims;
3355  for (size_t i = 0; i < inputRank; i++) {
3356  if (outputDims[i] == ShapedType::kDynamic) {
3357  newOutDims.push_back(
3358  (staticLow[i] == ShapedType::kDynamic ||
3359  staticHigh[i] == ShapedType::kDynamic ||
3360  inputDims[i] == ShapedType::kDynamic
3361  ? ShapedType::kDynamic
3362  : inputDims[i] + staticLow[i] + staticHigh[i]));
3363  } else {
3364  newOutDims.push_back(outputDims[i]);
3365  }
3366  }
3367 
3368  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3369  llvm::all_of(newOutDims,
3370  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3371  return failure();
3372 
3373  // Rewrite the op using the new static type.
3374  auto newResultType = RankedTensorType::get(
3375  newOutDims, padTensorOp.getType().getElementType());
3376  auto newOp = rewriter.create<PadOp>(
3377  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3378  newLows, newHighs, padTensorOp.getNofold(),
3379  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3380 
3381  IRMapping mapper;
3382  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3383  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3384  newOp);
3385 
3386  return success();
3387  }
3388 };
3389 
3390 } // namespace
3391 
3392 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3393  MLIRContext *context) {
3394  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3395  FoldOrthogonalPaddings, FoldStaticPadding>(context);
3396 }
3397 
3398 /// Return the padding value of the PadOp if it is constant. In this context,
3399 /// "constant" means an actual constant or "defined outside of the block".
3400 ///
3401 /// Values are considered constant in three cases:
3402 /// - A ConstantLike value.
3403 /// - A basic block argument from a different block.
3404 /// - A value defined outside of the block.
3405 ///
3406 /// If the padding value is not constant, an empty Value is returned.
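///
/// Hedged illustration (assumed IR, for exposition only): here %cst is
/// defined outside the pad body, so it is returned as the padding value:
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %0 = tensor.pad %t low[0, 1] high[1, 0] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<4x4xf32> to tensor<5x5xf32>
/// ```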
3407 Value PadOp::getConstantPaddingValue() {
3408  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3409  if (!yieldOp)
3410  return {};
3411  Value padValue = yieldOp.getValue();
3412  // Check if yield value is a constant.
3413  if (matchPattern(padValue, m_Constant()))
3414  return padValue;
3415  // Check if yield value is defined inside the PadOp block.
3416  if (padValue.getParentBlock() == &getRegion().front())
3417  return {};
3418  // Else: Yield value defined outside of the PadOp block.
3419  return padValue;
3420 }
3421 
3422 OpFoldResult PadOp::fold(FoldAdaptor) {
3423  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3424  !getNofold())
3425  return getSource();
3426  return {};
3427 }
3428 
3429 //===----------------------------------------------------------------------===//
3430 // ParallelInsertSliceOp
3431 //===----------------------------------------------------------------------===//
3432 
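// Hedged usage sketch (assumed IR, for exposition only): the op may only
// appear inside the terminator region of a parallel combining parent such as
// scf.forall:
//
// ```mlir
// %r = scf.forall (%i) in (4) shared_outs(%o = %dest) -> (tensor<4xf32>) {
//   %v = ...
//   scf.forall.in_parallel {
//     tensor.parallel_insert_slice %v into %o[%i] [1] [1]
//         : tensor<1xf32> into tensor<4xf32>
//   }
// }
// ```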
3433 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3434  ParallelCombiningOpInterface parallelCombiningParent =
3435  getParallelCombiningParent();
3436  for (const auto &it :
3437  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3438  Operation &nextOp = it.value();
3439  if (&nextOp == getOperation())
3440  return parallelCombiningParent.getParentResult(it.index());
3441  }
3442  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3443 }
3444 
3445 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3446 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3447  Value source, Value dest,
3448  ArrayRef<OpFoldResult> offsets,
3449  ArrayRef<OpFoldResult> sizes,
3450  ArrayRef<OpFoldResult> strides,
3451  ArrayRef<NamedAttribute> attrs) {
3452  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3453  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3454  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3455  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3456  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3457  result.addAttributes(attrs);
3458  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3459  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3460  b.getDenseI64ArrayAttr(staticSizes),
3461  b.getDenseI64ArrayAttr(staticStrides));
3462 }
3463 
3464 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3465 /// packed into a Range vector.
3466 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3467  Value source, Value dest,
3468  ArrayRef<Range> ranges,
3469  ArrayRef<NamedAttribute> attrs) {
3470  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3471  build(b, result, source, dest, offsets, sizes, strides, attrs);
3472 }
3473 
3474 // Build a ParallelInsertSliceOp with dynamic entries.
3475 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3476  Value source, Value dest, ValueRange offsets,
3477  ValueRange sizes, ValueRange strides,
3478  ArrayRef<NamedAttribute> attrs) {
3479  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3480  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3481  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3482  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3483  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3484  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3485  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3486 }
3487 
3488 LogicalResult ParallelInsertSliceOp::verify() {
3489  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3490  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3491  << *(getOperation()->getParentOp());
3492 
3493  RankedTensorType expectedType;
3494  SliceVerificationResult result =
3495  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3496  getStaticSizes(), getStaticStrides(), &expectedType);
3497  return produceSliceErrorMsg(result, *this, expectedType);
3498 }
3499 
3500 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3501  RewritePatternSet &results, MLIRContext *context) {
3502  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3503  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3504  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3505 }
3506 
3507 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3508  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3509 }
3510 
3511 //===----------------------------------------------------------------------===//
3512 // ScatterOp
3513 //===----------------------------------------------------------------------===//
3514 
3515 void ScatterOp::getAsmResultNames(
3516  function_ref<void(Value, StringRef)> setNameFn) {
3517  setNameFn(getResult(), "scatter");
3518 }
3519 
3520 LogicalResult ScatterOp::verify() {
3521  int64_t destRank = getDestType().getRank();
3522  ArrayRef<int64_t> scatterDims = getScatterDims();
3523  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3524  "scatter", "dest")))
3525  return failure();
3526 
3527  if (!getUnique())
3528  return emitOpError("requires 'unique' attribute to be set");
3529  // TODO: we could also check statically that there are fewer leading index
3530  // tensor dims than the dest dims. If this is not the case, the unique
3531  // attribute cannot be true.
3532 
3533  // Use the GatherOp::inferResultType on the `dest` type and verify the
3534  // expected type matches the source type.
3535  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3536  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3537  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3538  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3539  if (getSourceType() != expectedSourceType &&
3540  getSourceType() != expectedRankReducedSourceType) {
3541  return emitOpError("source type "
3542  "mismatch: "
3543  "expected ")
3544  << expectedSourceType << " or its rank-reduced variant "
3545  << expectedRankReducedSourceType << " (got: " << getSourceType()
3546  << ")";
3547  }
3548 
3549  return success();
3550 }
3551 
3552 //===----------------------------------------------------------------------===//
3553 // SplatOp
3554 //===----------------------------------------------------------------------===//
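// Hedged illustration (assumed IR, for exposition only): broadcasting a
// scalar into a statically or dynamically sized tensor:
//
// ```mlir
// %0 = tensor.splat %f : tensor<4x8xf32>
// %1 = tensor.splat %f[%d] : tensor<?x8xf32>
// ```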
3555 
3556 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3557  Type aggregateType, ValueRange dynamicSizes) {
3558  build(builder, result, aggregateType, element, dynamicSizes);
3559 }
3560 
3561 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3562  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3563  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3564  build(builder, result, aggregateType, element, dynamicSizes);
3565 }
3566 
3567 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3568  ArrayRef<OpFoldResult> sizes) {
3569  SmallVector<int64_t> staticShape;
3570  SmallVector<Value> dynamicSizes;
3571  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3572  build(builder, result, element, staticShape, dynamicSizes);
3573 }
3574 
3575 void SplatOp::getAsmResultNames(
3576  function_ref<void(Value, StringRef)> setNameFn) {
3577  setNameFn(getResult(), "splat");
3578 }
3579 
3580 LogicalResult SplatOp::verify() {
3581  if (getType().getNumDynamicDims() !=
3582  static_cast<int64_t>(getDynamicSizes().size()))
3583  return emitOpError("incorrect number of dynamic sizes, has ")
3584  << getDynamicSizes().size() << ", expected "
3585  << getType().getNumDynamicDims();
3586  return success();
3587 }
3588 
3589 LogicalResult
3590 SplatOp::reifyResultShapes(OpBuilder &builder,
3591  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3592  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3593  unsigned ctr = 0;
3594  for (int64_t i = 0; i < getType().getRank(); ++i) {
3595  if (getType().isDynamicDim(i)) {
3596  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3597  } else {
3598  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3599  }
3600  }
3601  return success();
3602 }
3603 
3604 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3605  auto constOperand = adaptor.getInput();
3606  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3607  return {};
3608 
3609  // Do not fold if the splat is not statically shaped
3610  if (!getType().hasStaticShape())
3611  return {};
3612 
3613  // SplatElementsAttr::get treats a single value for the second arg as a
3614  // splat.
3615  return SplatElementsAttr::get(getType(), {constOperand});
3616 }
3617 
3618 //===----------------------------------------------------------------------===//
3619 // PackOp/UnPackOp Common
3620 //===----------------------------------------------------------------------===//
3621 
3622 template <typename OpTy>
3623 static LogicalResult
3624 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3625  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3626  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3627  "applies to only pack or unpack operations");
3628  int64_t destRank = op.getDestRank();
3629  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3630  reifiedReturnShapes[0] =
3631  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3632  return success();
3633 }
3634 
3635 template <typename OpTy>
3636 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3637  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3638  "applies to only pack or unpack operations");
3639  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3640  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3641  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3642  assert(tiles.size() == dimsToTile.size() &&
3643  "tiles must match indices of dimension to block");
3644  // Bind dimension `i` to its tile factor.
3645  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3646  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3647  return dimAndTileMapping;
3648 }
3649 
3650 template <typename OpTy>
3651 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3652  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3653  "applies to only pack or unpack operations");
3654  Builder builder(op);
3655  SmallVector<OpFoldResult> mixedInnerTiles;
3656  unsigned dynamicValIndex = 0;
3657  for (int64_t staticTile : op.getStaticInnerTiles()) {
3658  if (!ShapedType::isDynamic(staticTile))
3659  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3660  else
3661  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3662  }
3663  return mixedInnerTiles;
3664 }
3665 
3666 template <typename OpTy>
3667 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3668  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3669  "applies to only pack or unpack operations");
3670  SmallVector<Value> dynamicTiles;
3671  SmallVector<int64_t> staticTiles;
3672  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3673  return staticTiles;
3674 }
3675 
3676 /// Returns true if `dimsPos` is invalid. It is invalid when:
3677 /// a) It contains duplicates.
3678 /// b) At least one dimension is out of bounds (a valid `dimPos` is >= 0 and < rank).
3679 /// c) The number of elements in `dimsPos` is greater than `rank`.
3680 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3681  size_t rank) {
3682  size_t dimsPosSize = dimsPos.size();
3683  if (dimsPosSize > rank)
3684  return true;
3685  DenseSet<int64_t> uniqued;
3686  for (int64_t dim : dimsPos)
3687  uniqued.insert(dim);
3688  if (dimsPosSize != uniqued.size())
3689  return true;
3690  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3691  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3692  });
3693 }
3694 
3695 /// Returns true if each dimension of `sourceShape` is smaller than or equal
3696 /// to the corresponding dimension of `limitShape` (dynamic dimensions pass).
3697 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3698  ArrayRef<int64_t> limitShape) {
3699  assert(
3700  sourceShape.size() == limitShape.size() &&
3701  "expected source shape rank, and limit of the shape to have same rank");
3702  return llvm::all_of(
3703  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3704  int64_t sourceExtent = std::get<0>(it);
3705  int64_t limit = std::get<1>(it);
3706  return ShapedType::isDynamic(sourceExtent) ||
3707  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3708  });
3709 }
3710 
3711 template <typename OpTy>
3712 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3713  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3714  "applies to only pack or unpack operations");
3715  Operation *op = packOrUnPack.getOperation();
3716 
3717  // Return true if we have a zero-value tile.
3718  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3719  return llvm::any_of(
3720  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3721  };
3722 
3723  // Verify tiles. Do not allow zero tiles.
3724  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3725  if (hasZeros(mixedTiles))
3726  return op->emitError("invalid zero tile factor");
3727 
3728  // Verify inner_dims_pos and outer_dims_perm.
3729  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3730  ? packOrUnPack.getSourceType()
3731  : packOrUnPack.getDestType();
3732  size_t unpackedRank = unpackedType.getRank();
3733  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3734  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3735  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3736  return op->emitError("invalid inner_dims_pos vector");
3737  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3738  return op->emitError("invalid outer_dims_perm vector");
3739  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3740  return op->emitError("outer_dims_perm must be a permutation or empty");
3741 
3742  // Tiling factors must be less than or equal to the input rank for pack (or
3743  // output rank for unpack), and must match the number of `inner_dims_pos`.
3744  if (mixedTiles.size() > unpackedRank) {
3745  return op->emitError("tiling factors must be less than or equal to the "
3746  "input rank for pack or output rank for unpack");
3747  }
3748  if (mixedTiles.size() != innerDimsPos.size()) {
3749  return op->emitError(
3750  "tiling factors must equal the number of dimensions to tile");
3751  }
3752 
3753  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3754  ? packOrUnPack.getDestType()
3755  : packOrUnPack.getSourceType();
3756  size_t packedRank = packedType.getRank();
3757  // Require output rank to match input rank + number of blocking factors.
3758  if (unpackedRank + mixedTiles.size() != packedRank) {
3759  return op->emitError(
3760  "packed rank must equal unpacked rank + tiling factors");
3761  }
3762 
3763  // Verify that the result shape is at least as large as the minimum shape
3764  // expected by the pack operation, and that the output shape represents
3765  // full tiles.
3766  RankedTensorType expectedPackedType = PackOp::inferPackedType(
3767  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3768  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3769  return op->emitError("the shape of output is not large enough to hold the "
3770  "packed data. Expected at least ")
3771  << expectedPackedType << ", got " << packedType;
3772  }
3773  if (!llvm::all_of(
3774  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3775  mixedTiles),
3776  [](std::tuple<int64_t, OpFoldResult> it) {
3777  std::optional<int64_t> constTileSize =
3778  getConstantIntValue(std::get<1>(it));
3779  int64_t shape = std::get<0>(it);
3780  if (!constTileSize) {
3781  // If specified tile size is dynamic, output shape should
3782  // be dynamic too.
3783  return ShapedType::isDynamic(shape);
3784  }
3785  if (ShapedType::isDynamic(shape)) {
3786  // For the shape being dynamic when tile size is
3787  // specified, return true. In canonical form a constant
3788  // tile size should lead to constant shape of the tiled
3789  // dimension, but not needed for verification.
3790  return true;
3791  }
3792  return shape == constTileSize.value();
3793  })) {
3794  return op->emitError("mismatch in inner tile sizes specified and shaped of "
3795  "tiled dimension in the packed type");
3796  }
3797  return success();
3798 }
3799 
3800 namespace {
3801 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
3802 /// various permutations to the op.
3803 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
3804 // these. These may or may not become true foldings / canonicalizations
3805 // depending on how aggressive we want to be in automatically folding
3806 // transposes.
3807 struct PackOrUnPackTransposeResult {
3808  SmallVector<int64_t> innerDimsPos;
3809  SmallVector<OpFoldResult> innerTiles;
3810  SmallVector<int64_t> outerDimsPerm;
3811 };
3812 } // namespace
3813 
3814 template <typename OpTy>
3815 static PackOrUnPackTransposeResult
3816 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
3817  ArrayRef<int64_t> innerPermutation,
3818  ArrayRef<int64_t> outerPermutation) {
3819  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3820  "applies to only pack or unpack operations");
3821  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3822  "some permutation must be non-empty");
3823  PackOrUnPackTransposeResult metadata;
3824  metadata.innerDimsPos =
3825  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
3826  metadata.innerTiles =
3827  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
3828  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3829  ? packOrUnPackOp.getSourceRank()
3830  : packOrUnPackOp.getDestRank();
3831  metadata.outerDimsPerm =
3832  packOrUnPackOp.getOuterDimsPerm().empty()
3833  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3834  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
3835  if (!innerPermutation.empty()) {
3836  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3837  isPermutationVector(innerPermutation) &&
3838  "invalid inner permutation");
3839  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
3840  applyPermutationToVector(metadata.innerTiles, innerPermutation);
3841  }
3842  if (!outerPermutation.empty()) {
3843  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3844  isPermutationVector(outerPermutation) &&
3845  "invalid outer permutation");
3846  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
3847  }
3848  return metadata;
3849 }
3850 
3851 //===----------------------------------------------------------------------===//
3852 // PackOp
3853 //===----------------------------------------------------------------------===//
3854 
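// Hedged illustration (assumed IR, for exposition only): tiling a 128x256
// tensor into 8x32 inner tiles:
//
// ```mlir
// %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//     into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// ```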
3855 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3856  setNameFn(getResult(), "pack");
3857 }
3858 
3859 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3860  Value dest, ArrayRef<int64_t> innerDimsPos,
3861  ArrayRef<OpFoldResult> innerTiles,
3862  std::optional<Value> paddingValue,
3863  ArrayRef<int64_t> outerDimsPerm) {
3864  assert(innerDimsPos.size() == innerTiles.size() &&
3865  "number of tile sizes specified must match the specified number of "
3866  "original dimensions to be tiled");
3867  SmallVector<int64_t> staticTileSizes;
3868  SmallVector<Value> dynamicTileSizes;
3869  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3870  build(builder, state, dest.getType(), source, dest,
3871  paddingValue ? *paddingValue : nullptr,
3872  outerDimsPerm.empty() ? nullptr
3873  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3874  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3875  builder.getDenseI64ArrayAttr(staticTileSizes));
3876 }
3877 
3878 LogicalResult
3879 PackOp::reifyResultShapes(OpBuilder &builder,
3880  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3881  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3882 }
3883 
3884 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3885  return getDimAndTileMappingImpl(*this);
3886 }
3887 
3888 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3889  return getMixedTilesImpl(*this);
3890 }
3891 
3892 SmallVector<int64_t> PackOp::getStaticTiles() {
3893  return getStaticTilesImpl(*this);
3894 }
3895 
3896 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
3897  ArrayRef<int64_t> innerDimsPos,
3898  ArrayRef<int64_t> outputShape,
3899  ArrayRef<int64_t> outerDimsPerm,
3900  ArrayRef<OpFoldResult> innerTiles) {
3901  SmallVector<int64_t> outputTileSizes(
3902  outputShape.take_front(inputShape.size()));
3903  if (!outerDimsPerm.empty()) {
3904  assert(outerDimsPerm.size() == outputTileSizes.size() &&
3905  "expected output and outer_dims_perm to have same size");
3906  applyPermutationToVector(outputTileSizes,
3907  invertPermutationVector(outerDimsPerm));
3908  }
3909  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3910  if (ShapedType::isDynamic(inputShape[pos]))
3911  continue;
3912  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3913 
3914  if (!constantTile) {
3915  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3916  (inputShape[pos] % outputTileSizes[pos] != 0))
3917  return true;
3918  } else if (inputShape[pos] % (*constantTile) != 0) {
3919  return true;
3920  }
3921  }
3922  return false;
3923 }
3924 
3925 LogicalResult PackOp::verify() {
3926  if (failed(commonVerifierPackAndUnPackOp(*this)))
3927  return failure();
3928 
3929  // Verify padding value, and bail out if the tile does not divide the
3930  // dimension fully. In the case of dynamic tile factors or dimensions, having
3931  // a partial tile is undefined behavior.
3932  auto paddingValue = getPaddingValue();
3933  if (paddingValue &&
3934  paddingValue.getType() != getSourceType().getElementType()) {
3935  return emitOpError("expected padding_value has ")
3936  << getSourceType().getElementType()
3937  << " but got: " << paddingValue.getType();
3938  }
3939 
3940  if (!paddingValue &&
3941  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3942  getDestType().getShape(), getOuterDimsPerm(),
3943  getMixedTiles())) {
3944  return emitOpError(
3945  "invalid tile factor or output size provided. Only full tiles are "
3946  "supported when padding_value is not set");
3947  }
3948  return success();
3949 }
3950 
3951 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
3952 /// Values to kDynamic, even if they are arith.constant values.
3953 static SmallVector<int64_t>
3954 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
3955  SmallVector<int64_t> result;
3956  for (auto o : ofrs) {
3957  // Have to do this first, as getConstantIntValue special-cases constants.
3958  if (llvm::dyn_cast_if_present<Value>(o))
3959  result.push_back(ShapedType::kDynamic);
3960  else
3961  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
3962  }
3963  return result;
3964 }
3965 
3966 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
3967 /// the packed type. A shared helper ensures that the two methods agree on
3968 /// which dimensions are dynamic.
3969 static SmallVector<int64_t> getPackOpResultTypeShape(
3970  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
3971  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
3972  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
3973  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
3974  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3975  continue;
3976  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3977  resultShape[tiledDim.value()] = ShapedType::kDynamic;
3978  continue;
3979  }
3980  resultShape[tiledDim.value()] = divideCeilSigned(
3981  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
3982  }
3983 
3984  // Swap tile loops if outer_dims_perm is available.
3985  if (!outerDimsPerm.empty())
3986  applyPermutationToVector(resultShape, outerDimsPerm);
3987 
3988  // Append the inner tile dimensions.
3989  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
3990  return resultShape;
3991 }
3992 
3993 SmallVector<OpFoldResult> PackOp::getResultShape(
3994  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
3995  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
3996  ArrayRef<int64_t> outerDimsPerm) {
3997  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
3998 
3999  AffineExpr s0, s1;
4000  bindSymbols(builder.getContext(), s0, s1);
4001  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4002  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4003  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4004  builder, loc, ceilDivExpr,
4005  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4006  }
4007  if (!outerDimsPerm.empty())
4008  applyPermutationToVector(resultDims, outerDimsPerm);
4009  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4010 
4011  SmallVector<int64_t> resultTypeShape =
4012  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
4013  asShapeWithAnyValueAsDynamic(innerTileSizes),
4014  innerDimsPos, outerDimsPerm);
4015 
4016  // Fix-up `resultDims` to ensure that they are Value's if and only if the
4017  // result type shape says it's a dynamic dim. This is needed as callers may
4018  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4019  // dynamic dims returned by that.
4020  for (unsigned i = 0; i < resultDims.size(); ++i) {
4021  if (!ShapedType::isDynamic(resultTypeShape[i]))
4022  continue;
4023  resultDims[i] =
4024  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4025  }
4026 
4027  return resultDims;
4028 }
4029 
4030 /// Get the expected packed type based on source type, tile factors, position of
4031 /// the inner tiles and permutation of the outer tiled loop.
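/// Illustrative sketch (not upstream documentation): a 128x256 source with
/// inner_tiles = [8, 32], inner_dims_pos = [0, 1] and outer_dims_perm = [1, 0]
/// infers tensor<8x16x8x32xf32>: the tiled outer dims become ceil(128/8) = 16
/// and ceil(256/32) = 8, are permuted to [8, 16], and the tiles are appended.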
4032 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4033  ArrayRef<int64_t> innerTileSizes,
4034  ArrayRef<int64_t> innerDimsPos,
4035  ArrayRef<int64_t> outerDimsPerm) {
4036  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4037  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4038  return RankedTensorType::get(resultShape, sourceType.getElementType());
4039 }
4040 
4041 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4042  ArrayRef<OpFoldResult> innerTileSizes,
4043  ArrayRef<int64_t> innerDimsPos,
4044  ArrayRef<int64_t> outerDimsPerm) {
4045  AffineExpr dim0, dim1;
4046  bindDims(b.getContext(), dim0, dim1);
4047  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4048  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4049  {v1, v2});
4050  };
4051 
4052  SmallVector<OpFoldResult> mixedSizes;
4053  for (auto [index, value] : llvm::enumerate(
4054  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4055  if (ShapedType::isDynamic(value))
4056  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
4057  else
4058  mixedSizes.push_back(b.getIndexAttr(value));
4059  }
4060  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4061  int64_t dimPos = std::get<0>(it);
4062  OpFoldResult tileSize = std::get<1>(it);
4063  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4064  }
4065  if (!outerDimsPerm.empty())
4066  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4067 
4068  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4069  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4070  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4071 }
4072 
4073 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4074  ArrayRef<int64_t> innerPermutation,
4075  ArrayRef<int64_t> outerPermutation) {
4076  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4077  *this, innerPermutation, outerPermutation);
4078  Value transposedDest =
4079  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4080  metadata.innerDimsPos, metadata.outerDimsPerm);
4081  return b.create<PackOp>(loc, getSource(), transposedDest,
4082  metadata.innerDimsPos, metadata.innerTiles,
4083  getPaddingValue(), metadata.outerDimsPerm);
4084 }
4085 
4086 /// Returns true if the tiles and the tiled dims are constant.
4087 template <typename OpTy>
4088 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
4089  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4090  "applies to only pack or unpack operations");
4091  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4092  ? op.getDestType()
4093  : op.getSourceType();
4094  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4095  for (auto [dimDest, tile] : llvm::zip(
4096  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4097  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4098  if (!constTileSize || ShapedType::isDynamic(dimDest))
4099  return false;
4100  }
4101  return true;
4102 }
4103 
4104 Speculation::Speculatability PackOp::getSpeculatability() {
4105  if (getPaddingValue())
4106  return Speculation::Speculatable;
4107 
4108  // The verifier already rejects operations if we can statically prove that
4109  // the tile sizes do not perfectly divide the dimension; thus, check only
4110  // that the tiles and the tiled inner dimensions are constant.
4111  if (!areTilesAndTiledDimsAllConstant(*this))
4112  return Speculation::NotSpeculatable;
4113 
4114  return Speculation::Speculatable;
4115 }
4116 
4117 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4118 // dimensions for pack and unpack.
4119 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4120  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4121  return false;
4122  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4123  return true;
4124  // The outer dims permutation is optional. To compare an unbalanced
4125  // pack-unpack pair, treat a missing permutation as the identity
4126  // permutation.
4127  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4128  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4129 }
4130 
4131 // Return true if pack and unpack have the same tiles.
4132 // Same SSA values or same integer constants.
4133 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4134  auto packTiles = packOp.getMixedTiles();
4135  auto unPackTiles = unPackOp.getMixedTiles();
4136  if (packTiles.size() != unPackTiles.size())
4137  return false;
4138  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4139  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4140  return false;
4141  }
4142  return true;
4143 }
4144 
4145 /// Returns true if the pack op does not need a padding value.
4146 static bool paddingIsNotNeeded(PackOp op) {
4147  auto srcType = op.getSourceType();
4148  if (llvm::any_of(op.getInnerDimsPos(),
4149  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4150  return false;
4151  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4152  return false;
4153  return !PackOp::requirePaddingValue(
4154  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4155  op.getOuterDimsPerm(), op.getMixedTiles());
4156 }
4157 
4158 /// Returns true if the `srcShape` or `destShape` is different from the one in
4159 /// `packOp` and populates each with the inferred static shape.
4160 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4161  SmallVectorImpl<int64_t> &destShape) {
4162  bool changeNeeded = false;
4163  srcShape.assign(packOp.getSourceType().getShape().begin(),
4164  packOp.getSourceType().getShape().end());
4165  destShape.assign(packOp.getDestType().getShape().begin(),
4166  packOp.getDestType().getShape().end());
4167  llvm::SmallSetVector<int64_t, 4> innerDims;
4168  innerDims.insert(packOp.getInnerDimsPos().begin(),
4169  packOp.getInnerDimsPos().end());
4170  SmallVector<int64_t> inverseOuterDimsPerm;
4171  if (!packOp.getOuterDimsPerm().empty())
4172  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4173  int srcRank = packOp.getSourceRank();
4174  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4175  if (innerDims.contains(i))
4176  continue;
4177  int64_t srcPos = i;
4178  int64_t destPos = i;
4179  if (!inverseOuterDimsPerm.empty())
4180  destPos = inverseOuterDimsPerm[srcPos];
4181  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4182  ShapedType::isDynamic(destShape[destPos])) {
4183  continue;
4184  }
4185  int64_t size = srcShape[srcPos];
4186  if (ShapedType::isDynamic(size))
4187  size = destShape[destPos];
4188  srcShape[srcPos] = size;
4189  destShape[destPos] = size;
4190  changeNeeded = true;
4191  }
4192  return changeNeeded;
4193 }
4194 
4195 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4196  // Fold pack(unpack(x)) to x.
4197  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4198  if (unPackOp.getSourceType() != packOp.getDestType())
4199  return failure();
4200  if (packOp.getPaddingValue() ||
4201  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4202  !haveSameTiles(packOp, unPackOp))
4203  return failure();
4204  rewriter.replaceOp(packOp, unPackOp.getSource());
4205  return success();
4206  }
4207 
4208  // Fold optional PaddingValue operand away if padding is not needed.
4209  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4210  rewriter.startOpModification(packOp);
4211  packOp.getPaddingValueMutable().clear();
4212  rewriter.finalizeOpModification(packOp);
4213  return success();
4214  }
4215 
4216  // Insert tensor.cast ops if static shape inference is available.
4217  SmallVector<int64_t> srcShape, destShape;
4218  if (inferStaticShape(packOp, srcShape, destShape)) {
4219  Location loc = packOp.getLoc();
4220  Value source = packOp.getSource();
4221  if (srcShape != packOp.getSourceType().getShape()) {
4222  auto newSrcType = packOp.getSourceType().clone(srcShape);
4223  source =
4224  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4225  }
4226  Value dest = packOp.getDest();
4227  if (destShape != packOp.getDestType().getShape()) {
4228  auto newDestType = packOp.getDestType().clone(destShape);
4229  dest =
4230  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4231  }
4232  Value newOp = rewriter.create<tensor::PackOp>(
4233  loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4234  packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4235  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4236  packOp, packOp.getResult().getType(), newOp);
4237  return success();
4238  }
4239 
4240  return failure();
4241 }
4242 
4243 template <typename PackOrUnpackOp>
4244 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4245  RankedTensorType packedTensorType) {
4246  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4247  std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4248  "Function meant for pack/unpack");
4249  // This is a pad if packing only adds ones and we don't transpose dimensions.
4250 
4251  // Check that we are not transposing any dimensions.
4252  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4253  int64_t numPackedDims = innerDimsPos.size();
4254  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4255  if (orderedDims != innerDimsPos) {
4256  // Dimensions don't happen in order.
4257  return false;
4258  }
4259 
4260  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4261  int64_t packedRank = packedTensorType.getRank();
4262  // At this point we know that we are taking numPackedDims outer
4263  // dimensions and pushing them all the way in as the innermost dimensions.
4264  // What is left in the outermost dimensions is, in this order:
4265  // - the factors of the packed dimensions, then
4266  // - the untouched dimensions.
4267  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4268  // if all the dimensions that bubble outward are ones.
4269  // Therefore check that all the dimensions but the numPackedDims innermost
4270  // ones are ones.
4271  return llvm::all_of(
4272  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4273  [&packedShape](int64_t i) { return packedShape[i] == 1; });
4274 }
4275 
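// Hedged worked example (not from the upstream source): packing tensor<13xf32>
// with inner tile 16 and a padding value yields tensor<1x16xf32>; the single
// untiled outer dimension is 1, so the pack behaves like a pad.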
4276 bool PackOp::isLikePad() {
4277  auto packedTensorType =
4278  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4279  return isLikePadUnPad(*this, packedTensorType);
4280 }
4281 
4282 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4283  std::optional<Attribute> paddingValue;
4284  if (auto pad = adaptor.getPaddingValue())
4285  paddingValue = pad;
4286  if (OpFoldResult reshapedSource = reshapeConstantSource(
4287  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4288  getDestType(), paddingValue))
4289  return reshapedSource;
4290  return {};
4291 }
4292 
4293 //===----------------------------------------------------------------------===//
4294 // UnPackOp
4295 //===----------------------------------------------------------------------===//
4296 
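// Hedged illustration (assumed IR, for exposition only): the inverse of the
// pack example above, reassembling the tiled layout into the original shape:
//
// ```mlir
// %unpacked = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//     into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
// ```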
4297 void UnPackOp::getAsmResultNames(
4298  function_ref<void(Value, StringRef)> setNameFn) {
4299  setNameFn(getResult(), "unpack");
4300 }
4301 
4302 LogicalResult
4303 UnPackOp::reifyResultShapes(OpBuilder &builder,
4304  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4305  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4306 }
4307 
4308 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4309  return getDimAndTileMappingImpl(*this);
4310 }
4311 
4312 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4313  return getMixedTilesImpl(*this);
4314 }
4315 
4316 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4317  return getStaticTilesImpl(*this);
4318 }
4319 
4320 LogicalResult UnPackOp::verify() {
4321  return commonVerifierPackAndUnPackOp(*this);
4322 }
4323 
4324 Speculation::Speculatability UnPackOp::getSpeculatability() {
4325  // See PackOp::getSpeculatability.
4326  if (!areTilesAndTiledDimsAllConstant(*this))
4327  return Speculation::NotSpeculatable;
4328 
4329  return Speculation::Speculatable;
4330 }
4331 
4332 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4333  Value dest, ArrayRef<int64_t> innerDimsPos,
4334  ArrayRef<OpFoldResult> innerTiles,
4335  ArrayRef<int64_t> outerDimsPerm) {
4336  assert(innerDimsPos.size() == innerTiles.size() &&
4337  "number of tile sizes specified must match the specified number of "
4338  "original dimensions to be tiled");
4339  SmallVector<int64_t> staticTileSizes;
4340  SmallVector<Value> dynamicTileSizes;
4341  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4342  build(builder, state, dest.getType(), source, dest,
4343  outerDimsPerm.empty() ? nullptr
4344  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4345  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4346  builder.getDenseI64ArrayAttr(staticTileSizes));
4347 }
4348 
4349 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4350  Value source,
4351  ArrayRef<OpFoldResult> innerTileSizes,
4352  ArrayRef<int64_t> innerDimsPos,
4353  ArrayRef<int64_t> outerDimsPerm) {
4354  AffineExpr sym0, sym1;
4355  bindSymbols(b.getContext(), sym0, sym1);
4356  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4357  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4358  };
4359 
4360  SmallVector<OpFoldResult> mixedSizes;
4361  auto srcType = llvm::cast<RankedTensorType>(source.getType());
4362  for (auto i :
4363  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4364  if (srcType.isDynamicDim(i))
4365  mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4366  else
4367  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4368  }
4369  if (!outerDimsPerm.empty()) {
4370  applyPermutationToVector<OpFoldResult>(
4371  mixedSizes, invertPermutationVector(outerDimsPerm));
4372  }
4373 
4374  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4375  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4376 
4377  auto elemType = srcType.getElementType();
4378  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4379 }
4380 
4381 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4382  Value transposedSource,
4383  ArrayRef<int64_t> innerPermutation,
4384  ArrayRef<int64_t> outerPermutation) {
4385  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4386  *this, innerPermutation, outerPermutation);
4387  return b.create<UnPackOp>(loc, transposedSource, getDest(),
4388  metadata.innerDimsPos, metadata.innerTiles,
4389  metadata.outerDimsPerm);
4390 }
4391 
4392 /// Returns true if the `srcShape` or `destShape` is different from the one in
4393 /// `op` and populates each with the inferred static shape.
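/// For example (illustrative): with innerDimsPos = [1] and no outer_dims_perm,
/// a source of type tensor<?x8x16xf32> and a destination of type
/// tensor<128x?xf32> are refined to tensor<128x8x16xf32> and tensor<128x?xf32>
/// respectively; only the untiled dimension 0 participates, while the tiled
/// dimension 1 is skipped.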
4394 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4395  SmallVectorImpl<int64_t> &destShape) {
4396  bool changeNeeded = false;
4397  srcShape.assign(op.getSourceType().getShape().begin(),
4398  op.getSourceType().getShape().end());
4399  destShape.assign(op.getDestType().getShape().begin(),
4400  op.getDestType().getShape().end());
4401  llvm::SmallSetVector<int64_t, 4> innerDims;
4402  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4403  SmallVector<int64_t> inverseOuterDimsPerm;
4404  if (!op.getOuterDimsPerm().empty())
4405  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4406  int destRank = op.getDestRank();
4407  for (auto i : llvm::seq<int64_t>(0, destRank)) {
4408  if (innerDims.contains(i))
4409  continue;
4410  int64_t srcPos = i;
4411  int64_t destPos = i;
4412  if (!inverseOuterDimsPerm.empty())
4413  srcPos = inverseOuterDimsPerm[destPos];
4414  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4415  ShapedType::isDynamic(destShape[destPos])) {
4416  continue;
4417  }
4418  int64_t size = srcShape[srcPos];
4419  if (ShapedType::isDynamic(size))
4420  size = destShape[destPos];
4421  srcShape[srcPos] = size;
4422  destShape[destPos] = size;
4423  changeNeeded = true;
4424  }
4425  return changeNeeded;
4426 }
4427 
4428 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4429  PatternRewriter &rewriter) {
4430  /// unpack(pack(x)) -> x
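  // For example (illustrative IR), assuming matching attributes and no
  // padding value:
  //   %p = tensor.pack %x inner_dims_pos = [0, 1] inner_tiles = [32, 16]
  //        into %d0 : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
  //   %u = tensor.unpack %p inner_dims_pos = [0, 1] inner_tiles = [32, 16]
  //        into %d1 : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
  // canonicalizes all uses of %u to %x.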
4431  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4432  if (packOp.getDestType() != unPackOp.getSourceType())
4433  return failure();
4434  if (packOp.getPaddingValue() ||
4435  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4436  !haveSameTiles(packOp, unPackOp))
4437  return failure();
4438  rewriter.replaceOp(unPackOp, packOp.getSource());
4439  return success();
4440  }
4441  /// unpack(destinationStyleOp(x)) -> unpack(x)
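  // For example (illustrative), if the dest is produced by a linalg.fill, the
  // unpack is rewritten to use the fill's init tensor directly: the unpack
  // only needs the destination for its shape and overwrites its contents.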
4442  if (auto dstStyleOp =
4443  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4444  auto destValue = cast<OpResult>(unPackOp.getDest());
4445  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4446  rewriter.modifyOpInPlace(unPackOp,
4447  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4448  return success();
4449  }
4450 
4451  // Insert tensor.cast ops if static shape inference is available.
4452  SmallVector<int64_t> srcShape, destShape;
4453  if (inferStaticShape(unPackOp, srcShape, destShape)) {
4454  Location loc = unPackOp.getLoc();
4455  Value source = unPackOp.getSource();
4456  if (srcShape != unPackOp.getSourceType().getShape()) {
4457  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4458  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4459  unPackOp.getSource());
4460  }
4461  Value dest = unPackOp.getDest();
4462  if (destShape != unPackOp.getDestType().getShape()) {
4463  auto newDestType = unPackOp.getDestType().clone(destShape);
4464  dest =
4465  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4466  }
4467  Value newOp = rewriter.create<tensor::UnPackOp>(
4468  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4469  unPackOp.getOuterDimsPerm());
4470  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4471  unPackOp, unPackOp.getResult().getType(), newOp);
4472  return success();
4473  }
4474 
4475  return failure();
4476 }
4477 
4478 bool UnPackOp::isLikeUnPad() {
4479  RankedTensorType packedTensorType = getSourceType();
4480  return isLikePadUnPad(*this, packedTensorType);
4481 }
4482 
4483 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4484  if (OpFoldResult reshapedSource = reshapeConstantSource(
4485  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4486  getResult().getType()))
4487  return reshapedSource;
4488  return {};
4489 }
4490 
4491 //===----------------------------------------------------------------------===//
4492 // Common Canonicalizers and Folders.
4493 //===----------------------------------------------------------------------===//
4494 
4495 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4496 /// the `tensor.cast` has a source that is more static than the consuming op.
4497 ///
4498 /// Example:
4499 /// ```mlir
4500 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4501 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4502 /// ```
4503 ///
4504 /// folds into:
4505 ///
4506 /// ```mlir
4507 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4508 /// ```
4509 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4510 /// can add the pattern to their canonicalizers.
4511 struct FoldTensorCastProducerOp
4512  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4513  using OpInterfaceRewritePattern<
4514  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4515 
4516  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4517  PatternRewriter &rewriter) const override {
4518  // InsertSliceOp has its own logic about folding tensor.cast ops.
4519  if (isa<InsertSliceOp>(op.getOperation()))
4520  return failure();
4521 
4522  // Exclude DPS ops that are also LoopLike from this interface as they
4523  // might need special handling of attached regions.
4524  if (isa<LoopLikeOpInterface>(op.getOperation()))
4525  return failure();
4526 
4527  // Fail if no operand comes from a tensor::CastOp that can be folded.
4528  bool hasTensorCastOperand =
4529  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4530  if (llvm::isa<BlockArgument>(opOperand.get()))
4531  return false;
4532  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4533  return castOp && canFoldIntoConsumerOp(castOp);
4534  });
4535  if (!hasTensorCastOperand)
4536  return failure();
4537 
4538  SmallVector<Type, 4> newResultTypes(op->getResultTypes());
4539  SmallVector<Value, 4> newOperands;
4540  newOperands.reserve(op->getNumOperands());
4541  // Assumes that the result has dpsInits followed by nonDpsInits.
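  // (Each DPS init operand is tied 1:1 to a result, so result types are
  // updated in the order the init operands are visited.)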
4542  int64_t dpsInitIdx = 0;
4543  for (OpOperand &opOperand : op->getOpOperands()) {
4544  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4545  bool fold = canFoldIntoConsumerOp(tensorCastOp);
4546  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4547  if (op.isDpsInit(&opOperand) &&
4548  !llvm::isa<MemRefType>(newOperands.back().getType()))
4549  newResultTypes[dpsInitIdx++] = newOperands.back().getType();
4550  }
4551 
4552  // Clone op.
4553  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
4554  SmallVector<Value, 4> replacements;
4555  replacements.reserve(newOp->getNumResults());
4556  for (auto [oldResult, newResult] :
4557  llvm::zip(op->getResults(), newOp->getResults())) {
4558  if (newResult.getType() != oldResult.getType()) {
4559  replacements.push_back(rewriter.create<tensor::CastOp>(
4560  op->getLoc(), oldResult.getType(), newResult));
4561  } else {
4562  replacements.push_back(newResult);
4563  }
4564  }
4565  rewriter.replaceOp(op, replacements);
4566 
4567  return success();
4568  }
4569 };
4570 
4571 //===----------------------------------------------------------------------===//
4572 // TensorDialect
4573 //===----------------------------------------------------------------------===//
4574 
4575 void TensorDialect::getCanonicalizationPatterns(
4576  RewritePatternSet &results) const {
4577  results.add<FoldTensorCastProducerOp>(getContext());
4578 }
4579 
4580 //===----------------------------------------------------------------------===//
4581 // TableGen'd op method definitions
4582 //===----------------------------------------------------------------------===//
4583 
4584 #define GET_OP_CLASSES
4585 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"