TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <algorithm>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::tensor;
37 
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
40 using llvm::mod;
41 
42 /// Materialize a single constant operation from a given attribute value with
43 /// the desired resultant type.
44 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
45  Attribute value, Type type,
46  Location loc) {
47  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
48  return op;
49  if (complex::ConstantOp::isBuildableWith(value, type))
50  return builder.create<complex::ConstantOp>(loc, type,
51  llvm::cast<ArrayAttr>(value));
52  return nullptr;
53 }
54 
55 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
56  int64_t dim) {
57  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
59  if (tensorType.isDynamicDim(dim))
60  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
61 
62  return builder.getIndexAttr(tensorType.getDimSize(dim));
63 }
64 
65 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
66  Location loc, Value value) {
67  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
68  SmallVector<OpFoldResult> result;
69  for (int64_t i = 0; i < tensorType.getRank(); ++i)
70  result.push_back(getMixedSize(builder, loc, value, i));
71  return result;
72 }
73 
74 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
75  OpResult opResult) {
76  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
77  assert(tensorType && "expected tensor type");
78 
79  // If the op has a destination, it implements DestinationStyleOpInterface and
80  // we can query the destination operand from that interface.
81  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
82  if (destOp)
83  return destOp.getTiedOpOperand(opResult)->get();
84 
85  // Otherwise, create a new destination tensor with the same shape.
86  OpBuilder::InsertionGuard g(b);
87  b.setInsertionPoint(opResult.getDefiningOp());
88 
89  // Compute sizes.
90  SmallVector<OpFoldResult> mixedSizes;
91  if (!tensorType.hasStaticShape()) {
92  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
93  ReifiedRankedShapedTypeDims reifiedShapes;
94  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
95  return failure();
96  mixedSizes = reifiedShapes[opResult.getResultNumber()];
97  } else {
98  // Static shape: Take static sizes directly.
99  for (int64_t sz : tensorType.getShape())
100  mixedSizes.push_back(b.getIndexAttr(sz));
101  }
102 
103  // Create empty tensor.
104  Value emptyTensor =
105  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
106  return emptyTensor;
107 }
108 
109 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
110  Operation *op,
111  SmallVector<Value> &result) {
112  for (OpResult opResult : op->getResults()) {
113  if (llvm::isa<TensorType>(opResult.getType())) {
114  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
115  if (failed(destination))
116  return failure();
117  result.push_back(*destination);
118  }
119  }
120  return success();
121 }
122 
123 bool mlir::tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
124  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
125  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
126  return rtp1.getShape() == rtp2.getShape() &&
127  rtp1.getElementType() == rtp2.getElementType();
128  return false;
129  }
130  return tp1 == tp2; // default implementation
131 }
132 
133 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
134 /// rank-extending tensor.insert_slice op.
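/// For example, given mixedSizes = [1, %sz, 1, 8] and reducedShape = [?, 8],
/// dimensions 0 and 2 are the dropped unit dimensions.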
135 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
136  ArrayRef<OpFoldResult> mixedSizes) {
137  llvm::SmallBitVector droppedDims(mixedSizes.size());
138  int64_t shapePos = reducedShape.size() - 1;
139 
140  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
141  size_t idx = mixedSizes.size() - size.index() - 1;
142  // Rank-reduced dims must have a static unit dimension.
143  bool isStaticUnitSize =
144  size.value().is<Attribute>() &&
145  llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
146 
147  if (shapePos < 0) {
148  // There are no more dims in the reduced shape. All remaining sizes must
149  // be rank-reduced dims.
150  assert(isStaticUnitSize && "expected unit dim");
151  droppedDims.set(idx);
152  continue;
153  }
154 
155  // Dim is preserved if the size is not a static 1.
156  if (!isStaticUnitSize) {
157  --shapePos;
158  continue;
159  }
160 
161  // Dim is preserved if the reduced shape dim is also 1.
162  if (reducedShape[shapePos] == 1) {
163  --shapePos;
164  continue;
165  }
166 
167  // Otherwise: Dim is dropped.
168  droppedDims.set(idx);
169  }
170 
171  assert(shapePos < 0 && "dimension mismatch");
172  return droppedDims;
173 }
174 
175 /// Given a ranked tensor type and a range of values that defines its dynamic
176 /// dimension sizes, turn all dynamic sizes that have a constant value into
177 /// static dimension sizes.
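/// For example, a tensor<?x?xf32> whose dynamic sizes are (%d0, %c4), where
/// %c4 is an arith.constant 4, folds to tensor<?x4xf32>, and only %d0 remains
/// in foldedDynamicSizes.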
178 static RankedTensorType
179 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
180  SmallVector<Value> &foldedDynamicSizes) {
181  SmallVector<int64_t> staticShape(type.getShape().begin(),
182  type.getShape().end());
183  assert(type.getNumDynamicDims() ==
184  static_cast<int64_t>(dynamicSizes.size()) &&
185  "incorrect number of dynamic sizes");
186 
187  // Compute new static and dynamic sizes.
188  unsigned ctr = 0;
189  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
190  if (type.isDynamicDim(i)) {
191  Value dynamicSize = dynamicSizes[ctr++];
192  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
193  if (cst.has_value()) {
194  // Dynamic size must be non-negative.
195  if (cst.value() < 0) {
196  foldedDynamicSizes.push_back(dynamicSize);
197  continue;
198  }
199  staticShape[i] = *cst;
200  } else {
201  foldedDynamicSizes.push_back(dynamicSize);
202  }
203  }
204  }
205 
206  return RankedTensorType::get(staticShape, type.getElementType(),
207  type.getEncoding());
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // BitcastOp
212 //===----------------------------------------------------------------------===//
213 
214 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
215  if (inputs.size() != 1 || outputs.size() != 1)
216  return false;
217  Type a = inputs.front(), b = outputs.front();
218  auto aT = dyn_cast<TensorType>(a);
219  auto bT = dyn_cast<TensorType>(b);
220  if (!aT || !bT)
221  return false;
222 
223  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
224  return false;
225 
226  return succeeded(verifyCompatibleShape(aT, bT));
227 }
228 
229 namespace {
230 
231 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
232 /// operation.
233 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
234  using OpRewritePattern<BitcastOp>::OpRewritePattern;
235
236  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
237  PatternRewriter &rewriter) const final {
238  auto tensorBitcastOperand =
239  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
240  if (!tensorBitcastOperand)
241  return failure();
242 
243  auto resultType = cast<TensorType>(tensorBitcast.getType());
244  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
245  tensorBitcastOperand.getOperand());
246  return success();
247  }
248 };
249 
250 } // namespace
251 
252 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
253  MLIRContext *context) {
254  results.add<ChainedTensorBitcast>(context);
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // CastOp
259 //===----------------------------------------------------------------------===//
260 
261 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
262  setNameFn(getResult(), "cast");
263 }
264 
265 /// Returns true if `target` is a ranked tensor type that preserves static
266 /// information available in the `source` ranked tensor type.
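/// For example, preservesStaticInformation(tensor<?x16xf32>, tensor<8x16xf32>)
/// is true, while preservesStaticInformation(tensor<8x16xf32>, tensor<?x16xf32>)
/// is false because the static size 8 would be lost.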
267 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
268  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
269  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
270 
271  // Requires RankedTensorType.
272  if (!sourceType || !targetType)
273  return false;
274 
275  // Requires same elemental type.
276  if (sourceType.getElementType() != targetType.getElementType())
277  return false;
278 
279  // Requires same rank.
280  if (sourceType.getRank() != targetType.getRank())
281  return false;
282 
283  // Requires same encoding.
284  if (sourceType.getEncoding() != targetType.getEncoding())
285  return false;
286 
287  // If cast is towards more static sizes along any dimension, don't fold.
288  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
289  if (!ShapedType::isDynamic(std::get<0>(t)) &&
290  ShapedType::isDynamic(std::get<1>(t)))
291  return false;
292  }
293 
294  return true;
295 }
296 
297 /// Determines whether tensor::CastOp casts to a more dynamic version of the
298 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
299 /// implement canonicalization patterns for ops in different dialects that may
300 /// consume the results of tensor.cast operations. Such foldable tensor.cast
301 /// operations are typically inserted as `slice` ops and are canonicalized
302 /// to preserve the type compatibility of their uses.
303 ///
304 /// Returns true when all conditions are met:
305 /// 1. source and result are ranked tensors with same element type and rank.
306 /// 2. the source type has more static information than the result type.
307 ///
308 /// Example:
309 /// ```mlir
310 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
311 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
312 /// ```
313 ///
314 /// folds into:
315 ///
316 /// ```mlir
317 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
318 /// ```
319 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
320  if (!castOp)
321  return false;
322 
323  // Can fold if the source of cast has at least as much static information as
324  // its results.
325  return preservesStaticInformation(castOp.getType(),
326  castOp.getSource().getType());
327 }
328 
329 /// Determines whether the tensor::CastOp casts to a more static version of the
330 /// source tensor. This is useful to fold into a producing op and implement
331 /// canonicalization patterns with the `tensor.cast` op as the root, but with
332 /// the producer being from a different dialect. Returns true when all conditions are met:
333 /// 1. source and result are ranked tensors with same element type and rank.
334 /// 2. the result type has more static information than the source.
335 ///
336 /// Example:
337 /// ```mlir
338 /// %1 = producer ... : tensor<?x?xf32>
339 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
340 /// ```
341 ///
342 /// can be canonicalized to :
343 ///
344 /// ```mlir
345 /// %2 = producer ... : tensor<8x16xf32>
346 /// ```
347 /// Not all ops might be canonicalizable this way, but for those that can be,
348 /// this method provides a check that it is worth doing the canonicalization.
349 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
350  if (!castOp)
351  return false;
352  return preservesStaticInformation(castOp.getSource().getType(),
353  castOp.getType());
354 }
355 
356 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
357 /// that can be folded.
358 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
359  bool folded = false;
360  for (OpOperand &operand : op->getOpOperands()) {
361  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
362  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
363  operand.set(castOp.getOperand());
364  folded = true;
365  }
366  }
367  return success(folded);
368 }
369 
370 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
371  if (inputs.size() != 1 || outputs.size() != 1)
372  return false;
373  Type a = inputs.front(), b = outputs.front();
374  auto aT = llvm::dyn_cast<TensorType>(a);
375  auto bT = llvm::dyn_cast<TensorType>(b);
376  if (!aT || !bT)
377  return false;
378 
379  if (aT.getElementType() != bT.getElementType())
380  return false;
381 
382  return succeeded(verifyCompatibleShape(aT, bT));
383 }
384 
385 /// Compute a TensorType that has the joined shape knowledge of the two
386 /// given TensorTypes. The element types need to match.
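/// For example, joining tensor<?x8xf32> with tensor<4x?xf32> yields
/// tensor<4x8xf32>, while joining tensor<4x8xf32> with tensor<5x8xf32> fails
/// and returns a null type.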
387 static TensorType joinShapes(TensorType one, TensorType two) {
388  assert(one.getElementType() == two.getElementType());
389 
390  if (!one.hasRank())
391  return two;
392  if (!two.hasRank())
393  return one;
394 
395  int64_t rank = one.getRank();
396  if (rank != two.getRank())
397  return {};
398 
399  SmallVector<int64_t, 4> join;
400  join.reserve(rank);
401  for (int64_t i = 0; i < rank; ++i) {
402  if (one.isDynamicDim(i)) {
403  join.push_back(two.getDimSize(i));
404  continue;
405  }
406  if (two.isDynamicDim(i)) {
407  join.push_back(one.getDimSize(i));
408  continue;
409  }
410  if (one.getDimSize(i) != two.getDimSize(i))
411  return {};
412  join.push_back(one.getDimSize(i));
413  }
414  return RankedTensorType::get(join, one.getElementType());
415 }
416 
417 namespace {
418 
419 /// Replaces chains of two tensor.cast operations by a single tensor.cast
420 /// operation if doing so does not remove runtime constraints.
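/// For example, casting tensor<4x?xf32> to tensor<?x?xf32> and then to
/// tensor<?x8xf32> folds into a single cast to tensor<?x8xf32>, whereas
/// casting tensor<?x?xf32> to tensor<4x?xf32> and then to tensor<?x8xf32> is
/// not folded, since dropping the intermediate cast would lose the runtime
/// check that the leading dimension is 4.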
421 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
422  using OpRewritePattern<CastOp>::OpRewritePattern;
423
424  LogicalResult matchAndRewrite(CastOp tensorCast,
425  PatternRewriter &rewriter) const final {
426  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
427 
428  if (!tensorCastOperand)
429  return failure();
430 
431  auto sourceType =
432  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
433  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
434  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
435 
436  // We can remove the intermediate cast if joining all three produces the
437  // same result as just joining the source and result shapes.
438  auto firstJoin =
439  joinShapes(joinShapes(sourceType, intermediateType), resultType);
440 
441  // The join might not exist if the cast sequence would fail at runtime.
442  if (!firstJoin)
443  return failure();
444 
445  // The newJoin always exists if the above join exists; it might just contain
446  // less information. If so, we cannot drop the intermediate cast, as doing
447  // so would remove runtime checks.
448  auto newJoin = joinShapes(sourceType, resultType);
449  if (firstJoin != newJoin)
450  return failure();
451 
452  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
453  tensorCastOperand.getOperand());
454  return success();
455  }
456 };
457 
458 /// Fold tensor.cast into tensor.extract_slice producer.
459 /// Example:
460 /// ```
461 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
462 /// tensor<128x512xf32> to tensor<?x512xf32>
463 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
464 /// ```
465 /// ->
466 /// ```
467 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
468 /// tensor<128x512xf32> to tensor<16x512xf32>
469 /// ```
470 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
471  using OpRewritePattern<CastOp>::OpRewritePattern;
472
473  LogicalResult matchAndRewrite(CastOp tensorCast,
474  PatternRewriter &rewriter) const final {
475  auto extractOperand =
476  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
477 
478  // Cannot fold cast to unranked tensor.
479  auto rankedResultType =
480  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
481  if (!rankedResultType)
482  return failure();
483 
484  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
485  rankedResultType.getShape() ==
486  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
487  .getShape())
488  return failure();
489 
490  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
491  auto dimMask = computeRankReductionMask(
492  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
493  size_t dimIndex = 0;
494  for (size_t i = 0, e = sizes.size(); i < e; i++) {
495  if (dimMask && dimMask->count(i))
496  continue;
497  int64_t dim = rankedResultType.getShape()[dimIndex++];
498  if (ShapedType::isDynamic(dim))
499  continue;
500  sizes[i] = rewriter.getIndexAttr(dim);
501  }
502 
503  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
504  tensorCast, rankedResultType, extractOperand.getSource(),
505  extractOperand.getMixedOffsets(), sizes,
506  extractOperand.getMixedStrides());
507  return success();
508  }
509 };
510 
511 } // namespace
512 
513 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
514  MLIRContext *context) {
515  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // ConcatOp
520 //===----------------------------------------------------------------------===//
521 
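/// Infer the result type of concatenating `inputTypes` along `dim`. For
/// example, concatenating tensor<3x?xf32> and tensor<3x4xf32> along dim 0
/// infers tensor<6x4xf32>, while concatenating them along dim 1 infers
/// tensor<3x?xf32>.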
522 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
523  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
524  auto tensorTypes =
525  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
526  return llvm::cast<RankedTensorType>(type);
527  }));
528  int64_t concatRank = tensorTypes[0].getRank();
529 
530  // The concatenation dim must be in the range [0, rank).
531  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
532 
533  SmallVector<int64_t> sizes(concatRank);
534  for (int64_t i = 0, e = concatRank; i < e; ++i) {
535  if (i == dim)
536  continue;
537  SaturatedInteger size;
538  for (auto tensorType : tensorTypes)
539  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
540  sizes[i] = size.asInteger();
541  }
542  auto concatSize = SaturatedInteger::wrap(0);
543  for (auto tensorType : tensorTypes)
544  concatSize =
545  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
546  sizes[dim] = concatSize.asInteger();
547  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
548 }
549 
550 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
551  ValueRange inputs) {
552  FailureOr<RankedTensorType> resultType =
553  inferResultType(dim, inputs.getTypes());
554  assert(succeeded(resultType) && "failed to infer concatenation result type");
555  build(builder, result, *resultType, dim, inputs);
556 }
557 
558 LogicalResult ConcatOp::verify() {
559  if (getInputs().size() < 1)
560  return emitOpError("requires at least one input");
561 
562  SmallVector<RankedTensorType> inputTypes;
563  for (auto input : getInputs())
564  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
565 
566  RankedTensorType resultType = getResultType();
567  int64_t resultRank = getRank();
568  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
569  return type.getRank() != resultRank;
570  }))
571  return emitOpError("rank of concatenated inputs must match result rank");
572 
573  Type resultElementType = resultType.getElementType();
574  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
575  return type.getElementType() != resultElementType;
576  }))
577  return emitOpError("inputs and result element type must match");
578 
579  int64_t dim = getDim();
580  if (dim >= resultRank)
581  return emitOpError("concatenation dim must be less than the tensor rank");
582 
583  SmallVector<int64_t> sizes(resultRank);
584  for (int64_t i = 0, e = resultRank; i < e; ++i) {
585  if (i == dim)
586  continue;
587  SaturatedInteger size;
588  for (auto tensorType : inputTypes) {
589  FailureOr<SaturatedInteger> maybeSize =
590  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
591  if (failed(maybeSize))
592  return emitOpError("static concatenation size mismatch along ")
593  << "non-concatenated dimension " << i;
594  size = *maybeSize;
595  }
596  sizes[i] = size.asInteger();
597  }
598  auto concatSize = SaturatedInteger::wrap(0);
599  for (auto tensorType : inputTypes)
600  concatSize =
601  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
602  sizes[dim] = concatSize.asInteger();
603  auto inferredResultType =
604  RankedTensorType::get(sizes, inputTypes[0].getElementType());
605 
606  for (auto [inferredSize, actualSize] :
607  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
608  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
609  ShapedType::isDynamic(actualSize);
610  if (!hasDynamic && inferredSize != actualSize)
611  return emitOpError("result type ")
612  << resultType << " does not match inferred shape "
613  << inferredResultType << " static sizes";
614  }
615 
616  return success();
617 }
618 
619 LogicalResult
620 ConcatOp::reifyResultShapes(OpBuilder &builder,
621  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
622  ValueRange inputs = getInputs();
623  int64_t dim = getDim();
624  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
625 
626  Value init = inputs[0];
627  int64_t rank = getType().getRank();
628 
629  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
630 
631  // Pre-populate the result sizes with as much static information as possible
632  // from the given result type, as well as the inferred result type;
633  // otherwise, use the dim sizes from the first input.
634  for (int64_t i = 0; i < rank; ++i) {
635  if (i == dim)
636  continue;
637  if (!getType().isDynamicDim(i)) {
638  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
639  } else if (!inferredResultType.isDynamicDim(i)) {
640  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
641  builder, getLoc(),
642  builder.getIndexAttr(inferredResultType.getDimSize(i)));
643  } else {
644  reifiedReturnShapes[0][i] =
645  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
646  }
647  }
648 
649  if (getType().isDynamicDim(dim)) {
650  // Take the sum of the input sizes along the concatenated dim.
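  // E.g., for three inputs this composes the affine map
  // (d0, d1, d2) -> (d0 + d1 + d2) over the per-input sizes along `dim`.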
651  AffineExpr sum = builder.getAffineDimExpr(0);
652  SmallVector<OpFoldResult> sizes = {
653  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
654  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
655  sum = sum + builder.getAffineDimExpr(idx + 1);
656  sizes.push_back(
657  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
658  }
659  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
660  builder, getLoc(),
661  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
662  } else {
663  // If the result shape is static along the concatenated dim, use the static
664  // shape.
665  reifiedReturnShapes[0][dim] =
666  builder.getIndexAttr(getType().getDimSize(dim));
667  }
668  return success();
669 }
670 
671 void ConcatOp::getAsmResultNames(
672  function_ref<void(Value, StringRef)> setNameFn) {
673  setNameFn(getResult(), "concat");
674 }
675 
676 OpFoldResult ConcatOp::fold(FoldAdaptor) {
677  ValueRange inputs = getInputs();
678  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
679  return inputs[0];
680  return {};
681 }
682 
683 namespace {
684 /// Fold a concat op with a single input to a cast.
685 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
686  using OpRewritePattern<ConcatOp>::OpRewritePattern;
687
688  LogicalResult matchAndRewrite(ConcatOp concatOp,
689  PatternRewriter &rewriter) const override {
690  if (concatOp.getInputs().size() != 1)
691  return failure();
692  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
693  concatOp.getInputs()[0]);
694  return success();
695  }
696 };
697 } // namespace
698 
699 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
700  MLIRContext *context) {
701  results.add<SingleInputConcatOp>(context);
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // DimOp
706 //===----------------------------------------------------------------------===//
707 
708 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
709  setNameFn(getResult(), "dim");
710 }
711 
712 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
713  int64_t index) {
714  auto loc = result.location;
715  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
716  build(builder, result, source, indexValue);
717 }
718 
719 std::optional<int64_t> DimOp::getConstantIndex() {
720  return getConstantIntValue(getIndex());
721 }
722 
723 Speculation::Speculatability DimOp::getSpeculatability() {
724  auto constantIndex = getConstantIndex();
725  if (!constantIndex)
726  return Speculation::NotSpeculatable;
727
728  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
729  if (!rankedSourceType)
730  return Speculation::NotSpeculatable;
731
732  if (rankedSourceType.getRank() <= constantIndex)
733  return Speculation::NotSpeculatable;
734
735  return Speculation::Speculatable;
736 }
737 
738 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
739  // All forms of folding require a known index.
740  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
741  if (!index)
742  return {};
743 
744  // Folding for unranked types (UnrankedTensorType) is not supported.
745  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
746  if (!tensorType)
747  return {};
748 
749  // Out of bound indices produce undefined behavior but are still valid IR.
750  // Don't choke on them.
751  int64_t indexVal = index.getInt();
752  if (indexVal < 0 || indexVal >= tensorType.getRank())
753  return {};
754 
755  // Fold if the shape extent along the given index is known.
756  if (!tensorType.isDynamicDim(index.getInt())) {
757  Builder builder(getContext());
758  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
759  }
760 
761  Operation *definingOp = getSource().getDefiningOp();
762 
763  // Fold dim to the operand of tensor.generate.
764  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
765  auto resultType =
766  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
767  // The case where the type encodes the size of the dimension is handled
768  // above.
769  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
770 
771  // Find the operand of the fromElements that corresponds to this index.
772  auto dynExtents = fromElements.getDynamicExtents().begin();
773  for (auto dim : resultType.getShape().take_front(index.getInt()))
774  if (ShapedType::isDynamic(dim))
775  dynExtents++;
776 
777  return Value{*dynExtents};
778  }
779 
780  // The size at the given index is now known to be a dynamic size.
781  unsigned unsignedIndex = index.getValue().getZExtValue();
782 
783  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
784  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
785  // `resolve-shaped-type-result-dims` pass.
786  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
787  sliceOp.isDynamicSize(unsignedIndex)) {
788  return {sliceOp.getDynamicSize(unsignedIndex)};
789  }
790  }
791 
792  // dim(cast) -> dim
793  if (succeeded(foldTensorCast(*this)))
794  return getResult();
795 
796  return {};
797 }
798 
799 namespace {
800 /// Fold dim of a cast into the dim of the source of the tensor cast.
801 struct DimOfCastOp : public OpRewritePattern<DimOp> {
802  using OpRewritePattern<DimOp>::OpRewritePattern;
803
804  LogicalResult matchAndRewrite(DimOp dimOp,
805  PatternRewriter &rewriter) const override {
806  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
807  if (!castOp)
808  return failure();
809  Value newSource = castOp.getOperand();
810  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
811  return success();
812  }
813 };
814 
815 /// Fold dim of a destination passing style op into the dim of the corresponding
816 /// init.
817 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
818  using OpRewritePattern<DimOp>::OpRewritePattern;
819
820  LogicalResult matchAndRewrite(DimOp dimOp,
821  PatternRewriter &rewriter) const override {
822  auto source = dimOp.getSource();
823  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
824  if (!destOp)
825  return failure();
826 
827  auto resultIndex = cast<OpResult>(source).getResultNumber();
828  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
829 
830  rewriter.modifyOpInPlace(
831  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
832  return success();
833  }
834 };
835 
836 /// Fold dim of a tensor reshape operation into an extract from the reshape's
837 /// shape operand.
838 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
839  using OpRewritePattern<DimOp>::OpRewritePattern;
840
841  LogicalResult matchAndRewrite(DimOp dim,
842  PatternRewriter &rewriter) const override {
843  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
844 
845  if (!reshape)
846  return failure();
847 
848  // Since tensors are immutable, we don't need to worry about where to place
849  // the extract call.
850  rewriter.setInsertionPointAfter(dim);
851  Location loc = dim.getLoc();
852  Value extract =
853  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
854  if (extract.getType() != dim.getType())
855  extract =
856  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
857  rewriter.replaceOp(dim, extract);
858  return success();
859  }
860 };
861 } // namespace
862 
863 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
864  MLIRContext *context) {
865  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // EmptyOp
870 //===----------------------------------------------------------------------===//
871 
872 void EmptyOp::build(OpBuilder &builder, OperationState &result,
873  ArrayRef<int64_t> staticShape, Type elementType,
874  Attribute encoding) {
875  assert(all_of(staticShape,
876  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
877  "expected only static sizes");
878  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
879 }
880 
881 void EmptyOp::build(OpBuilder &builder, OperationState &result,
882  ArrayRef<int64_t> staticShape, Type elementType,
883  ValueRange dynamicSizes, Attribute encoding) {
884  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
885  build(builder, result, tensorType, dynamicSizes);
886 }
887 
888 void EmptyOp::build(OpBuilder &builder, OperationState &result,
889  ArrayRef<OpFoldResult> sizes, Type elementType,
890  Attribute encoding) {
891  SmallVector<int64_t> staticShape;
892  SmallVector<Value> dynamicSizes;
893  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
894  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
895 }
896 
897 LogicalResult EmptyOp::verify() {
898  if (getType().getNumDynamicDims() !=
899  static_cast<int64_t>(getDynamicSizes().size()))
900  return emitOpError("incorrect number of dynamic sizes, has ")
901  << getDynamicSizes().size() << ", expected "
902  << getType().getNumDynamicDims();
903  return success();
904 }
905 
906 LogicalResult
907 EmptyOp::reifyResultShapes(OpBuilder &builder,
908  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
909  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
910  unsigned ctr = 0;
911  for (int64_t i = 0; i < getType().getRank(); ++i) {
912  if (getType().isDynamicDim(i)) {
913  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
914  } else {
915  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
916  }
917  }
918  return success();
919 }
920 
921 Value EmptyOp::getDynamicSize(unsigned idx) {
922  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
923  unsigned ctr = 0;
924  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
925  if (getType().isDynamicDim(i))
926  ++ctr;
927  return getDynamicSizes()[ctr];
928 }
929 
930 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
931  SmallVector<OpFoldResult> result;
932  unsigned ctr = 0;
933  OpBuilder b(getContext());
934  for (int64_t i = 0; i < getType().getRank(); ++i) {
935  if (getType().isDynamicDim(i)) {
936  result.push_back(getDynamicSizes()[ctr++]);
937  } else {
938  result.push_back(b.getIndexAttr(getType().getShape()[i]));
939  }
940  }
941  return result;
942 }
943 
944 namespace {
945 /// Change the type of the result of a `tensor.empty` by making the result
946 /// type statically sized along dimensions that were dynamic in the original
947 /// operation but whose size was given by a `constant` op. For
948 /// example
949 ///
950 /// %c5 = arith.constant 5: index
951 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
952 ///
953 /// to
954 ///
955 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
956 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
957  using OpRewritePattern<EmptyOp>::OpRewritePattern;
958
959  LogicalResult matchAndRewrite(EmptyOp op,
960  PatternRewriter &rewriter) const override {
961  SmallVector<Value> foldedDynamicSizes;
962  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
963  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
964 
965  // Stop here if no dynamic size was promoted to static.
966  if (foldedTensorType == op.getType())
967  return failure();
968 
969  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
970  foldedDynamicSizes);
971  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
972  return success();
973  }
974 };
975 
976 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
977  using OpRewritePattern<DimOp>::OpRewritePattern;
978
979  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
980  PatternRewriter &rewriter) const override {
981  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
982  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
983  if (!emptyTensorOp || !maybeConstantIndex)
984  return failure();
985  if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
986  return failure();
987  rewriter.replaceOp(dimOp,
988  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
989  return success();
990  }
991 };
992 
993 /// Canonicalize
994 ///
995 /// ```mlir
996 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
997 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
998 /// ```
999 ///
1000 /// into
1001 ///
1002 /// ```mlir
1003 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1004 /// ```
1005 ///
1006 /// This assumes the input program is correct in terms of its shape. So it is
1007 /// safe to assume that `%d0` is in fact 4.
1008 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1009  using OpRewritePattern<CastOp>::OpRewritePattern;
1010
1011  LogicalResult matchAndRewrite(CastOp castOp,
1012  PatternRewriter &rewriter) const override {
1013  if (!canFoldIntoProducerOp(castOp))
1014  return failure();
1015  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1016  if (!producer)
1017  return failure();
1018 
1019  auto resultType =
1020  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1021  ArrayRef<int64_t> resultShape = resultType.getShape();
1022  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1023  SmallVector<OpFoldResult> newMixedSizes;
1024  newMixedSizes.reserve(currMixedSizes.size());
1025  assert(resultShape.size() == currMixedSizes.size() &&
1026  "mismatch in result shape and sizes of empty op");
1027  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1028  int64_t newDim = std::get<0>(it);
1029  OpFoldResult currDim = std::get<1>(it);
1030  // Case 1: The empty tensor dim is static. Check that the tensor cast
1031  // result dim matches.
1032  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1033  if (ShapedType::isDynamic(newDim) ||
1034  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1035  // Something is off, the cast result shape cannot be more dynamic
1036  // than the empty tensor result shape (enforced by
1037  // `canFoldIntoProducer`). Abort for now.
1038  return rewriter.notifyMatchFailure(
1039  producer, "mismatch in static value of shape of empty tensor "
1040  "result and cast result");
1041  }
1042  newMixedSizes.push_back(attr);
1043  continue;
1044  }
1045 
1046  // Case 2 : The tensor cast shape is static, but empty tensor result
1047  // shape is dynamic.
1048  if (!ShapedType::isDynamic(newDim)) {
1049  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1050  continue;
1051  }
1052 
1053  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1054  // shape is dynamic. Use the dynamic value from the empty tensor op.
1055  newMixedSizes.push_back(currDim);
1056  }
1057 
1058  // TODO: Do not drop tensor encoding.
1059  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1060  resultType.getElementType());
1061  return success();
1062  }
1063 };
1064 
1065 } // namespace
1066 
1067 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068  MLIRContext *context) {
1069  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1070  ReplaceEmptyTensorStaticShapeDims>(context);
1071 }
1072 
1073 /// Try to remove a tensor operation if it would only reshape a constant.
1074 /// Removes the op and replaces the constant with a new constant of the result
1075 /// shape. When an optional cst attribute is passed, it is reshaped only if the
1076 /// splat value matches the value in the attribute.
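/// For example, reshaping a splat constant dense<1.0> : tensor<2x3xf32> to a
/// static result type tensor<6xf32> folds to dense<1.0> : tensor<6xf32>.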
1077 static OpFoldResult
1078 reshapeConstantSource(DenseElementsAttr source, TensorType result,
1079  std::optional<Attribute> cst = std::nullopt) {
1080  if (source && source.isSplat() && result.hasStaticShape() &&
1081  (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
1082  return source.resizeSplat(result);
1083 
1084  return {};
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // ExtractOp
1089 //===----------------------------------------------------------------------===//
1090 
1091 namespace {
1092 
1093 /// Canonicalizes the pattern of the form
1094 ///
1095 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1096 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1097 ///
1098 /// to
1099 ///
1100 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1101 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1102  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1103
1104  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1105  PatternRewriter &rewriter) const final {
1106  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1107  if (!tensorCast)
1108  return failure();
1109  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1110  return failure();
1111  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1112  extract, tensorCast.getSource(), extract.getIndices());
1113  return success();
1114  }
1115 };
1116 
1117 } // namespace
1118 
1119 void ExtractOp::getAsmResultNames(
1120  function_ref<void(Value, StringRef)> setNameFn) {
1121  setNameFn(getResult(), "extracted");
1122 }
1123 
1124 LogicalResult ExtractOp::verify() {
1125  // Verify the # indices match if we have a ranked type.
1126  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1127  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1128  return emitOpError("incorrect number of indices for extract_element");
1129  return success();
1130 }
1131 
1132 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1133  // If this is a splat elements attribute, simply return the value. All of
1134  // the elements of a splat attribute are the same.
1135  if (Attribute tensor = adaptor.getTensor())
1136  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1137  return splatTensor.getSplatValue<Attribute>();
1138 
1139  // Collect the constant indices into the tensor.
1140  SmallVector<uint64_t, 8> indices;
1141  for (Attribute indice : adaptor.getIndices()) {
1142  if (!indice || !llvm::isa<IntegerAttr>(indice))
1143  return {};
1144  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1145  }
1146 
1147  // Fold extract(from_elements(...)).
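  // The constant indices are linearized in row-major order below; e.g., for a
  // tensor<2x3xi32>, indices [1, 2] map to flat index 1 * 3 + 2 = 5.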
1148  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1149  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1150  auto rank = tensorType.getRank();
1151  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1152  "rank mismatch");
1153  int flatIndex = 0;
1154  int stride = 1;
1155  for (int i = rank - 1; i >= 0; --i) {
1156  flatIndex += indices[i] * stride;
1157  stride *= tensorType.getDimSize(i);
1158  }
1159  // Prevent out of bounds accesses. This can happen in invalid code that
1160  // will never execute.
1161  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1162  flatIndex < 0)
1163  return {};
1164  return fromElementsOp.getElements()[flatIndex];
1165  }
1166 
1167  // If this is an elements attribute, query the value at the given indices.
1168  if (Attribute tensor = adaptor.getTensor()) {
1169  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1170  if (elementsAttr && elementsAttr.isValidIndex(indices))
1171  return elementsAttr.getValues<Attribute>()[indices];
1172  }
1173 
1174  return {};
1175 }
1176 
1177 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1178  MLIRContext *context) {
1179  results.add<ExtractFromTensorCast>(context);
1180 }
1181 
1182 //===----------------------------------------------------------------------===//
1183 // FromElementsOp
1184 //===----------------------------------------------------------------------===//
1185 
1186 void FromElementsOp::getAsmResultNames(
1187  function_ref<void(Value, StringRef)> setNameFn) {
1188  setNameFn(getResult(), "from_elements");
1189 }
1190 
1191 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1192  ValueRange elements) {
1193  assert(!elements.empty() && "expected at least one element");
1194  Type resultType = RankedTensorType::get(
1195  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1196  build(builder, result, resultType, elements);
1197 }
1198 
1199 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1200  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1201  return DenseElementsAttr::get(getType(), adaptor.getElements());
1202  return {};
1203 }
1204 
1205 namespace {
1206 
1207 // Pushes the index_casts that occur before extractions to after the extract.
1208 // This minimizes type conversion in some cases and enables the extract
1209 // canonicalizer. This changes:
1210 //
1211 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1212 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1213 //
1214 // to the following:
1215 //
1216 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1217 // %cast = arith.index_cast %extract : i32 to index
1218 //
1219 // so that the index_cast now operates on a scalar instead of a tensor.
1220 //
1221 // Consider expanding this to a template and handling all tensor cast
1222 // operations.
1223 struct ExtractElementFromIndexCast
1224  : public OpRewritePattern<tensor::ExtractOp> {
1225  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1226
1227  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1228  PatternRewriter &rewriter) const final {
1229  Location loc = extract.getLoc();
1230  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1231  if (!indexCast)
1232  return failure();
1233 
1234  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1235 
1236  auto newExtract = rewriter.create<tensor::ExtractOp>(
1237  loc, elementTy, indexCast.getIn(), extract.getIndices());
1238 
1239  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1240  newExtract);
1241 
1242  return success();
1243  }
1244 };
1245 
1246 } // namespace
1247 
1248 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1249  MLIRContext *context) {
1250  results.add<ExtractElementFromIndexCast>(context);
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // GatherOp
1255 //===----------------------------------------------------------------------===//
1256 
1257 void GatherOp::getAsmResultNames(
1258  function_ref<void(Value, StringRef)> setNameFn) {
1259  setNameFn(getResult(), "gather");
1260 }
1261 
1262 /// Return the inferred result type for a gatherOp where:
1263 /// - sourceType is the type of the source tensor gathered from
1264 /// - indicesType is the type of the indices used to gather
1265 /// - gatherDims are the dims along which the gather occurs.
1266 /// Return a full rank or ranked-reduced variant of the type depending on
1267 /// the value of rankReduced.
1268 ///
1269 /// The leading dimensions of the index tensor give the result tensor its
1270 /// leading dimensions.
1271 /// The trailing dimensions of the result tensor are obtained from the source
1272 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1273 /// rankReduced is false), or skipping them (otherwise).
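/// For example, gathering from a tensor<4x5x6xf32> source with indices of type
/// tensor<10x2xindex> and gather_dims = [0, 2] infers tensor<10x1x5x1xf32>, or
/// tensor<10x5xf32> when rank-reducing.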
1274 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1275  RankedTensorType indicesType,
1276  ArrayRef<int64_t> gatherDims,
1277  bool rankReduced) {
1278  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1279  resultShape.reserve(resultShape.size() + sourceType.getRank());
1280  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1281  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1282  if (!rankReduced)
1283  resultShape.push_back(1);
1284  continue;
1285  }
1286  resultShape.push_back(sourceType.getDimSize(idx));
1287  }
1288  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1289 }
1290 
1291 static LogicalResult
1292 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
1293  StringRef gatherOrScatter, StringRef sourceOrDest) {
1294  if (dims.empty())
1295  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1296 
1297  int64_t numGatherDims = dims.size();
1298  if (numGatherDims > rank)
1299  return op->emitOpError(gatherOrScatter)
1300  << "_dims overflow " << sourceOrDest << " rank";
1301  for (int64_t val : dims) {
1302  if (val < 0)
1303  return op->emitOpError(gatherOrScatter)
1304  << "_dims value must be non-negative";
1305  if (val >= rank)
1306  return op->emitOpError(gatherOrScatter)
1307  << "_dims value must be smaller than " << sourceOrDest << " rank";
1308  }
1309  for (int64_t i = 1; i < numGatherDims; ++i) {
1310  if (dims[i - 1] >= dims[i])
1311  return op->emitOpError(gatherOrScatter)
1312  << "_dims values must be strictly increasing";
1313  }
1314  return success();
1315 }
1316 
1317 LogicalResult GatherOp::verify() {
1318  int64_t sourceRank = getSourceType().getRank();
1319  ArrayRef<int64_t> gatherDims = getGatherDims();
1320  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
1321  "gather", "source")))
1322  return failure();
1323 
1324  RankedTensorType expectedResultType = GatherOp::inferResultType(
1325  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1326  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1327  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1328  if (getResultType() != expectedResultType &&
1329  getResultType() != expectedRankReducedResultType) {
1330  return emitOpError("result type "
1331  "mismatch: "
1332  "expected ")
1333  << expectedResultType << " or its rank-reduced variant "
1334  << expectedRankReducedResultType << " (got: " << getResultType()
1335  << ")";
1336  }
1337 
1338  return success();
1339 }
1340 
1341 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1342  if (OpFoldResult reshapedSource = reshapeConstantSource(
1343  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1344  getResult().getType()))
1345  return reshapedSource;
1346  return {};
1347 }
1348 
1349 //===----------------------------------------------------------------------===//
1350 // InsertOp
1351 //===----------------------------------------------------------------------===//
1352 
1353 void InsertOp::getAsmResultNames(
1354  function_ref<void(Value, StringRef)> setNameFn) {
1355  setNameFn(getResult(), "inserted");
1356 }
1357 
1358 LogicalResult InsertOp::verify() {
1359  // Verify the # indices match if we have a ranked type.
1360  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1361  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1362  return emitOpError("incorrect number of indices");
1363  return success();
1364 }
1365 
1366 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1367  Attribute scalar = adaptor.getScalar();
1368  Attribute dest = adaptor.getDest();
1369  if (scalar && dest)
1370  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1371  if (scalar == splatDest.getSplatValue<Attribute>())
1372  return dest;
1373  return {};
1374 }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // GenerateOp
1378 //===----------------------------------------------------------------------===//
1379 
1380 void GenerateOp::getAsmResultNames(
1381  function_ref<void(Value, StringRef)> setNameFn) {
1382  setNameFn(getResult(), "generated");
1383 }
1384 
1385 LogicalResult GenerateOp::reifyResultShapes(
1386  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1387  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1388  int idx = 0;
1389  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1390  if (getType().isDynamicDim(dim)) {
1391  reifiedReturnShapes[0][dim] = getOperand(idx++);
1392  } else {
1393  reifiedReturnShapes[0][dim] =
1394  builder.getIndexAttr(getType().getDimSize(dim));
1395  }
1396  }
1397  return success();
1398 }
1399 
1400 LogicalResult GenerateOp::verify() {
1401  // Ensure that the tensor type has as many dynamic dimensions as are
1402  // specified by the operands.
1403  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1404  if (getNumOperands() != resultType.getNumDynamicDims())
1405  return emitError("must have as many index operands as dynamic extents "
1406  "in the result type");
1407  return success();
1408 }
1409 
1410 LogicalResult GenerateOp::verifyRegions() {
1411  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1412  // Ensure that region arguments span the index space.
1413  if (!llvm::all_of(getBody().getArgumentTypes(),
1414  [](Type ty) { return ty.isIndex(); }))
1415  return emitError("all body arguments must be index");
1416  if (getBody().getNumArguments() != resultTy.getRank())
1417  return emitError("must have one body argument per input dimension");
1418 
1419  // Ensure that the region yields an element of the right type.
1420  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1421 
1422  if (yieldOp.getValue().getType() != resultTy.getElementType())
1423  return emitOpError(
1424  "body must be terminated with a `yield` operation of the tensor "
1425  "element type");
1426 
1427  return success();
1428 }
1429 
1430 void GenerateOp::build(
1431  OpBuilder &b, OperationState &result, Type resultTy,
1432  ValueRange dynamicExtents,
1433  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1434  build(b, result, resultTy, dynamicExtents);
1435 
1436  // Build and populate body.
1437  OpBuilder::InsertionGuard guard(b);
1438  Region *bodyRegion = result.regions.front().get();
1439  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1440  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1441  SmallVector<Location, 2> argumentLocs(rank, result.location);
1442  Block *bodyBlock =
1443  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1444  bodyBuilder(b, result.location, bodyBlock->getArguments());
1445 }
1446 
1447 namespace {
1448 
1449 /// Canonicalizes tensor.generate operations with a constant
1450 /// operand into the equivalent operation with the operand expressed in the
1451 /// result type, instead. We also insert a type cast to make sure that the
1452 /// resulting IR is still well-typed.
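/// For example, %0 = tensor.generate %d, %c5 : tensor<?x?xindex> (where %c5 is
/// an arith.constant 5) becomes tensor.generate %d : tensor<?x5xindex>
/// followed by a tensor.cast back to tensor<?x?xindex>.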
1453 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1454  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1455
1456  LogicalResult matchAndRewrite(GenerateOp generateOp,
1457  PatternRewriter &rewriter) const final {
1458  SmallVector<Value> foldedDynamicSizes;
1459  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1460  generateOp.getType(), generateOp.getDynamicExtents(),
1461  foldedDynamicSizes);
1462 
1463  // Stop here if no dynamic size was promoted to static.
1464  if (foldedTensorType == generateOp.getType())
1465  return failure();
1466 
1467  auto loc = generateOp.getLoc();
1468  auto newOp =
1469  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1470  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1471  newOp.getBody().begin());
1472  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1473  generateOp.getType(), newOp);
1474  return success();
1475  }
1476 };
1477 
1478 /// Canonicalizes the pattern of the form
1479 ///
1480 /// %tensor = tensor.generate %x {
1481 /// ^bb0(%arg0: index):
1482 /// <computation>
1483 /// yield %1 : index
1484 /// } : tensor<?xindex>
1485 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1486 ///
1487 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1488 /// tensor.generate operation has no side-effects.
1489 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1490  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1491
1492  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1493  PatternRewriter &rewriter) const final {
1494  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1495  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1496  return failure();
1497 
1498  IRMapping mapping;
1499  Block *body = &tensorFromElements.getBody().front();
1500  mapping.map(body->getArguments(), extract.getIndices());
1501  for (auto &op : body->without_terminator())
1502  rewriter.clone(op, mapping);
1503 
1504  auto yield = cast<YieldOp>(body->getTerminator());
1505 
1506  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1507  return success();
1508  }
1509 };
1510 
1511 } // namespace
1512 
1513 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1514  MLIRContext *context) {
1515  // TODO: Move extract pattern to tensor::ExtractOp.
1516  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1517 }
1518 
1519 //===----------------------------------------------------------------------===//
1520 // RankOp
1521 //===----------------------------------------------------------------------===//
1522 
1523 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1524  setNameFn(getResult(), "rank");
1525 }
1526 
1527 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1528  // Constant fold rank when the rank of the operand is known.
1529  auto type = getOperand().getType();
1530  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1531  if (shapedType && shapedType.hasRank())
1532  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1533  return IntegerAttr();
1534 }
1535 
1536 //===----------------------------------------------------------------------===//
1537 // ReshapeOp
1538 //===----------------------------------------------------------------------===//
1539 
1540 void ReshapeOp::getAsmResultNames(
1541  function_ref<void(Value, StringRef)> setNameFn) {
1542  setNameFn(getResult(), "reshape");
1543 }
1544 
1545 static int64_t getNumElements(ShapedType type) {
1546  int64_t numElements = 1;
1547  for (auto dim : type.getShape())
1548  numElements *= dim;
1549  return numElements;
1550 }
1551 
1552 LogicalResult ReshapeOp::verify() {
1553  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1554  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1555 
1556  if (operandType.getElementType() != resultType.getElementType())
1557  return emitOpError("element types of source and destination tensor "
1558  "types should be the same");
1559 
1560  int64_t shapeSize =
1561  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1562  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1563  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1564 
1565  if (resultRankedType) {
1566  if (operandRankedType && resultRankedType.hasStaticShape() &&
1567  operandRankedType.hasStaticShape()) {
1568  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1569  return emitOpError("source and destination tensor should have the "
1570  "same number of elements");
1571  }
1572  if (ShapedType::isDynamic(shapeSize))
1573  return emitOpError("cannot use shape operand with dynamic length to "
1574  "reshape to statically-ranked tensor type");
1575  if (shapeSize != resultRankedType.getRank())
1576  return emitOpError(
1577  "length of shape operand differs from the result's tensor rank");
1578  }
1579  return success();
1580 }
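/// For illustration, a hypothetical op that satisfies the checks above (names
/// are made up): the shape operand has static length 1, matching the result
/// rank, and both tensors hold 16 elements:
/// ```mlir
/// %r = tensor.reshape %t(%shape)
///     : (tensor<4x4xf32>, tensor<1xindex>) -> tensor<16xf32>
/// ```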
1581 
1582 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1583  if (OpFoldResult reshapedSource = reshapeConstantSource(
1584  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1585  getResult().getType()))
1586  return reshapedSource;
1587 
1588  // If the producer of the 'source' operand is another 'tensor.reshape' op, use
1589  // the producer's input instead as the original tensor to reshape. This may
1590  // leave the producer without uses, so it can be removed as dead code.
1591  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1592  getSourceMutable().assign(reshapeOpProducer.getSource());
1593  return getResult();
1594  }
1595 
1596  auto source = getSource();
1597  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1598  auto resultTy = dyn_cast<RankedTensorType>(getType());
1599  if (!sourceTy || !resultTy || sourceTy != resultTy)
1600  return {};
1601 
1602  // If the source and result are both 1D tensors and have the same type, the
1603  // reshape has no effect, even if the tensor is dynamically shaped.
1604  if (sourceTy.getRank() == 1)
1605  return source;
1606 
1607  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1608  auto elements = fromElements.getElements();
1609  bool dynamicNoop =
1610  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1611  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1612  auto element = elements[id];
1613 
1614  if (auto cst = getConstantIntValue(element)) {
1615  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1616  continue;
1617  }
1618 
1619  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1620  dynamicNoop &= dimOp.getSource() == source;
1621 
1622  APSInt dim;
1623  auto cst = getConstantIntValue(dimOp.getIndex());
1624  dynamicNoop &=
1625  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1626  continue;
1627  }
1628 
1629  dynamicNoop = false;
1630  break;
1631  }
1632 
1633  if (dynamicNoop)
1634  return source;
1635  }
1636 
1637  return {};
1638 }
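/// For illustration, a hypothetical instance of the dynamic no-op case handled
/// above (names are made up): the shape operand reproduces the source's own
/// sizes, so the reshape folds to %src.
/// ```mlir
/// %d0 = tensor.dim %src, %c0 : tensor<?x8xf32>
/// %shape = tensor.from_elements %d0, %c8 : tensor<2xindex>
/// %r = tensor.reshape %src(%shape)
///     : (tensor<?x8xf32>, tensor<2xindex>) -> tensor<?x8xf32>
/// ```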
1639 
1640 //===----------------------------------------------------------------------===//
1641 // Reassociative reshape ops
1642 //===----------------------------------------------------------------------===//
1643 
1644 void CollapseShapeOp::getAsmResultNames(
1645  function_ref<void(Value, StringRef)> setNameFn) {
1646  setNameFn(getResult(), "collapsed");
1647 }
1648 
1649 void ExpandShapeOp::getAsmResultNames(
1650  function_ref<void(Value, StringRef)> setNameFn) {
1651  setNameFn(getResult(), "expanded");
1652 }
1653 
1654 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1655  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1656  "invalid resultDim");
1657  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1658  if (llvm::is_contained(it.value(), resultDim))
1659  return it.index();
1660  llvm_unreachable("could not find reassociation group");
1661 }
1662 
1663 FailureOr<SmallVector<OpFoldResult>>
1664 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1665  RankedTensorType expandedType,
1666  ArrayRef<ReassociationIndices> reassociation,
1667  ArrayRef<OpFoldResult> inputShape) {
1668  std::optional<SmallVector<OpFoldResult>> outputShape =
1669  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1670  inputShape);
1671  if (!outputShape)
1672  return failure();
1673  return *outputShape;
1674 }
1675 
1676 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1677  Type resultType, Value src,
1678  ArrayRef<ReassociationIndices> reassociation,
1679  ArrayRef<OpFoldResult> outputShape) {
1680  auto [staticOutputShape, dynamicOutputShape] =
1681      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1682  build(builder, result, cast<RankedTensorType>(resultType), src,
1683  getReassociationIndicesAttribute(builder, reassociation),
1684  dynamicOutputShape, staticOutputShape);
1685 }
1686 
1687 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1688  Type resultType, Value src,
1689  ArrayRef<ReassociationIndices> reassociation) {
1690  SmallVector<OpFoldResult> inputShape =
1691  getMixedSizes(builder, result.location, src);
1692  auto tensorResultTy = cast<RankedTensorType>(resultType);
1693  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1694  builder, result.location, tensorResultTy, reassociation, inputShape);
1695  SmallVector<OpFoldResult> outputShapeOrEmpty;
1696  if (succeeded(outputShape)) {
1697  outputShapeOrEmpty = *outputShape;
1698  }
1699  build(builder, result, tensorResultTy, src, reassociation,
1700  outputShapeOrEmpty);
1701 }
1702 
1703 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1704  return getSymbolLessAffineMaps(getReassociationExprs());
1705 }
1706 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1707  return convertReassociationIndicesToExprs(getContext(),
1708                                            getReassociationIndices());
1709 }
1710 
1711 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1712  return getSymbolLessAffineMaps(getReassociationExprs());
1713 }
1714 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1715  return convertReassociationIndicesToExprs(getContext(),
1716                                            getReassociationIndices());
1717 }
1718 
1719 RankedTensorType CollapseShapeOp::inferCollapsedType(
1720  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1721  return inferCollapsedType(
1722      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1723                type.getContext(), reassociation)));
1724 }
1725 
1726 /// Compute the RankedTensorType obtained by applying `reassociation` to
1727 /// `type`.
1728 RankedTensorType
1729 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1730  ArrayRef<AffineMap> reassociation) {
1731  auto shape = type.getShape();
1732  SmallVector<int64_t, 4> newShape;
1733  newShape.reserve(reassociation.size());
1734 
1735  // Use the fact that reassociation is valid to simplify the logic: only use
1736  // each map's rank.
1737  assert(isReassociationValid(reassociation) && "invalid reassociation");
1738  unsigned currentDim = 0;
1739  for (AffineMap m : reassociation) {
1740  unsigned dim = m.getNumResults();
1741  auto band = shape.slice(currentDim, dim);
1742  int64_t size = 1;
1743  if (llvm::is_contained(band, ShapedType::kDynamic))
1744  size = ShapedType::kDynamic;
1745  else
1746  for (unsigned d = 0; d < dim; ++d)
1747  size *= shape[currentDim + d];
1748  newShape.push_back(size);
1749  currentDim += dim;
1750  }
1751 
1752  return RankedTensorType::get(newShape, type.getElementType());
1753 }
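/// For illustration, a hypothetical application of the inference above: with
/// reassociation [[0, 1], [2]], a tensor<?x4x5xf32> collapses to
/// tensor<?x5xf32>, since any dynamic size in a group makes the collapsed
/// dimension dynamic.
/// ```mlir
/// %c = tensor.collapse_shape %t [[0, 1], [2]]
///     : tensor<?x4x5xf32> into tensor<?x5xf32>
/// ```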
1754 
1755 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1756  ArrayRef<ReassociationIndices> reassociation,
1757  ArrayRef<NamedAttribute> attrs) {
1758  auto resultType = inferCollapsedType(
1759  llvm::cast<RankedTensorType>(src.getType()),
1760      getSymbolLessAffineMaps(
1761          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1762  result.addAttribute(getReassociationAttrStrName(),
1763  getReassociationIndicesAttribute(b, reassociation));
1764  build(b, result, resultType, src, attrs);
1765 }
1766 
1767 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1768  TensorReshapeOp, ExpandShapeOp>::value>
1769 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1770  RankedTensorType expandedType,
1771  RankedTensorType collapsedType) {
1772  if (failed(
1773  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1774  return failure();
1775 
1776  auto maps = op.getReassociationMaps();
1777  RankedTensorType expectedType =
1778  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1779  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1780  return op.emitOpError("expected collapsed type to be ")
1781  << expectedType << ", but got " << collapsedType;
1782  return success();
1783 }
1784 
1785 LogicalResult ExpandShapeOp::verify() {
1786  auto srcType = getSrcType();
1787  auto resultType = getResultType();
1788 
1789  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1790  return emitOpError("expected number of static shape dims to be equal to "
1791  "the output rank (")
1792  << resultType.getRank() << ") but found "
1793  << getStaticOutputShape().size() << " inputs instead";
1794 
1795  if ((int64_t)getOutputShape().size() !=
1796  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1797  return emitOpError("mismatch in dynamic dims in output_shape and "
1798  "static_output_shape: static_output_shape has ")
1799  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1800  << " dynamic dims while output_shape has " << getOutputShape().size()
1801  << " values";
1802 
1803  return verifyTensorReshapeOp(*this, resultType, srcType);
1804 }
1805 
1806 LogicalResult CollapseShapeOp::verify() {
1807  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1808 }
1809 
1810 namespace {
1811 /// Reshape of a splat constant can be replaced with a constant of the result
1812 /// type.
1813 template <typename TensorReshapeOp>
1814 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1816  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1817  PatternRewriter &rewriter) const override {
1818  DenseElementsAttr attr;
1819  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1820  return failure();
1821  if (!attr || !attr.isSplat())
1822  return failure();
1823  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1824      reshapeOp.getResultType(), attr.getRawData());
1825  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1826  return success();
1827  }
1828 };
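/// For illustration, a hypothetical instance of the splat-constant fold above
/// (names are made up):
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<2x4xf32>
/// %e = tensor.expand_shape %cst [[0], [1, 2]] output_shape [2, 2, 2]
///     : tensor<2x4xf32> into tensor<2x2x2xf32>
/// // folds to
/// %e = arith.constant dense<1.0> : tensor<2x2x2xf32>
/// ```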
1829 
1830 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1831 template <typename TensorReshapeOp>
1832 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1833 public:
1835 
1836  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1837  PatternRewriter &rewriter) const override {
1838  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1839  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1840  return failure();
1841 
1842  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1843  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1844  return success();
1845  }
1846 };
1847 
1848 /// Reshape of a FromElements can be replaced with a FromElements of the
1849 /// result type
1850 template <typename TensorReshapeOp>
1851 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1853  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1854  PatternRewriter &rewriter) const override {
1855  auto fromElements =
1856  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1857  if (!fromElements)
1858  return failure();
1859 
1860  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1861 
1862  if (!shapedTy.hasStaticShape())
1863  return failure();
1864 
1865  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1866  fromElements.getElements());
1867  return success();
1868  }
1869 };
1870 
1871 // Fold CastOp into CollapseShapeOp when adding static information.
1872 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1874 
1875  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1876  PatternRewriter &rewriter) const override {
1877  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1878  if (!tensor::canFoldIntoConsumerOp(castOp))
1879  return failure();
1880 
1881  RankedTensorType srcType =
1882  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1883  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1884  srcType, collapseShapeOp.getReassociationMaps());
1885 
1886  if (newResultType == collapseShapeOp.getResultType()) {
1887  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1888  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1889  });
1890  } else {
1891  auto newOp = rewriter.create<CollapseShapeOp>(
1892  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1893  collapseShapeOp.getReassociation());
1894  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1895  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1896  }
1897  return success();
1898  }
1899 };
1900 
1901 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1903 
1904  LogicalResult matchAndRewrite(DimOp dimOp,
1905  PatternRewriter &rewriter) const override {
1906  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1907  if (!expandShapeOp)
1908  return failure();
1909 
1910  // Only constant dimension values are supported.
1911  std::optional<int64_t> dim = dimOp.getConstantIndex();
1912  if (!dim.has_value())
1913  return failure();
1914 
1915  // Skip static dims. These are folded to constant ops.
1916  RankedTensorType resultType = expandShapeOp.getResultType();
1917  if (!resultType.isDynamicDim(*dim))
1918  return failure();
1919 
1920  // Find reassociation group that contains this result dimension.
1921  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1922 
1923  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1924  // ExpandShapeOp would be ambiguous.)
1925  int64_t product = 1;
1926  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1927  for (int64_t d : grp) {
1928  if (d != dim) {
1929  assert(!resultType.isDynamicDim(d) && "expected static dim");
1930  product *= resultType.getDimSize(d);
1931  }
1932  }
1933 
1934  // result dim size = src dim size / (product(other dims in reassoc group))
1935  Value srcDimSz =
1936  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1937  AffineExpr expr;
1938  bindSymbols(dimOp.getContext(), expr);
1939  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1940  dimOp, expr.floorDiv(product), srcDimSz);
1941  return success();
1942  }
1943 };
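/// For illustration, a hypothetical rewrite produced by the pattern above
/// (names are made up): the dynamic result dimension is recovered from the
/// source dimension divided by the static sizes in its reassociation group.
/// ```mlir
/// %e = tensor.expand_shape %t [[0, 1]] output_shape [%n, 4]
///     : tensor<?xf32> into tensor<?x4xf32>
/// %d = tensor.dim %e, %c0 : tensor<?x4xf32>
/// // becomes
/// %s = tensor.dim %t, %c0 : tensor<?xf32>
/// %d = affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%s]
/// ```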
1944 
1945 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1947 
1948  LogicalResult matchAndRewrite(DimOp dimOp,
1949  PatternRewriter &rewriter) const override {
1950  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1951  if (!collapseShapeOp)
1952  return failure();
1953 
1954  // Only constant dimension values are supported.
1955  std::optional<int64_t> dim = dimOp.getConstantIndex();
1956  if (!dim.has_value())
1957  return failure();
1958 
1959  // Skip static dims. These are folded to constant ops.
1960  RankedTensorType resultType = collapseShapeOp.getResultType();
1961  if (!resultType.isDynamicDim(*dim))
1962  return failure();
1963 
1964  // Get reassociation group of the result dimension.
1965  ReassociationIndices group =
1966  collapseShapeOp.getReassociationIndices()[*dim];
1967 
1968  // result dim size = product(dims in reassoc group)
1969  SmallVector<Value> srcDimSizes;
1970  SmallVector<AffineExpr> syms;
1971  AffineExpr product;
1972  for (const auto &it : llvm::enumerate(group)) {
1973  srcDimSizes.push_back(rewriter.create<DimOp>(
1974  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1975  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1976  product = product ? product * syms.back() : syms.back();
1977  }
1978  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
1979  srcDimSizes);
1980  return success();
1981  }
1982 };
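/// For illustration, a hypothetical rewrite produced by the pattern above
/// (names are made up): the dynamic collapsed dimension is the product of the
/// source dimensions in its reassociation group.
/// ```mlir
/// %c = tensor.collapse_shape %t [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
/// %d = tensor.dim %c, %c0 : tensor<?xf32>
/// // becomes
/// %d0 = tensor.dim %t, %c0 : tensor<?x4xf32>
/// %d1 = tensor.dim %t, %c1 : tensor<?x4xf32>
/// %d = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%d0, %d1]
/// ```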
1983 } // namespace
1984 
1985 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1986  MLIRContext *context) {
1987  results.add<
1990  FoldReshapeWithConstant<ExpandShapeOp>,
1991  FoldReshapeWithSplat<ExpandShapeOp>,
1992  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1993  FoldDimOfCollapseShape>(context);
1994 }
1995 
1996 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1997  MLIRContext *context) {
1998  results.add<
2000  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2001  tensor::DimOp, RankedTensorType>,
2002  FoldReshapeWithConstant<CollapseShapeOp>,
2003  FoldReshapeWithSplat<CollapseShapeOp>,
2004  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2005  context);
2006 }
2007 
2008 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2009  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2010  adaptor.getOperands());
2011 }
2012 
2013 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2014  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2015  adaptor.getOperands());
2016 }
2017 
2018 //===----------------------------------------------------------------------===//
2019 // ExtractSliceOp
2020 //===----------------------------------------------------------------------===//
2021 
2022 void ExtractSliceOp::getAsmResultNames(
2023  function_ref<void(Value, StringRef)> setNameFn) {
2024  setNameFn(getResult(), "extracted_slice");
2025 }
2026 
2027 /// An extract_slice result type can be inferred, when it is not
2028 /// rank-reduced, from the source type and the static representation of
2029 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2030 RankedTensorType ExtractSliceOp::inferResultType(
2031  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2032  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2033  // An extract_slice op may specify only a leading subset of offsets/sizes/
2034  // strides, in which case we complete them with offset=0, sizes from the
2035  // source tensor type, and strides=1.
2036  assert(static_cast<int64_t>(staticSizes.size()) ==
2037  sourceTensorType.getRank() &&
2038  "unexpected staticSizes not equal to rank of source");
2039  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2040  sourceTensorType.getEncoding());
2041 }
2042 
2043 RankedTensorType ExtractSliceOp::inferResultType(
2044  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2045  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2046  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2047  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2048  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2049  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2050  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2051  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2052  staticSizes, staticStrides);
2053 }
2054 
2055 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2056 /// number of sizes), drop as many size 1 as needed to produce an inferred
2057 /// type with the desired rank.
2058 ///
2059 /// Note that there may be multiple ways to compute this rank-reduced type:
2060 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2061 ///
2062 /// To disambiguate, this function always drops the earliest size-1 dimensions.
2063 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2064  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2065  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2066  ArrayRef<int64_t> strides) {
2067  // Type inferred in the absence of rank-reducing behavior.
2068  auto inferredType = llvm::cast<RankedTensorType>(
2069  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2070  int rankDiff = inferredType.getRank() - desiredResultRank;
2071  if (rankDiff > 0) {
2072  auto shape = inferredType.getShape();
2073  llvm::SmallBitVector dimsToProject =
2074  getPositionsOfShapeOne(rankDiff, shape);
2075  SmallVector<int64_t> projectedShape;
2076  // Best effort rank-reducing: drop 1s in order.
2077  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2078  if (!dimsToProject.test(pos))
2079  projectedShape.push_back(shape[pos]);
2080  inferredType =
2081  RankedTensorType::get(projectedShape, inferredType.getElementType());
2082  }
2083  return inferredType;
2084 }
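/// For illustration, a hypothetical rank-reduction following the rule above:
/// slicing a tensor<1x6x1xf32> with sizes [1, 6, 1] down to rank 2 drops the
/// leading unit dimension first, yielding tensor<6x1xf32> rather than
/// tensor<1x6xf32>.
/// ```mlir
/// %s = tensor.extract_slice %t[0, 0, 0] [1, 6, 1] [1, 1, 1]
///     : tensor<1x6x1xf32> to tensor<6x1xf32>
/// ```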
2085 
2086 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2087  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2088  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2089  ArrayRef<OpFoldResult> strides) {
2090  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2091  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2092  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2093  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2094  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2095  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2096  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2097  staticStrides);
2098 }
2099 
2100 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2101 /// result type. If the type passed is nullptr, it is inferred.
2102 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2103  RankedTensorType resultType, Value source,
2104  ArrayRef<OpFoldResult> offsets,
2105  ArrayRef<OpFoldResult> sizes,
2106  ArrayRef<OpFoldResult> strides,
2107  ArrayRef<NamedAttribute> attrs) {
2108  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2109  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2110  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2111  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2112  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2113  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2114  // Structuring implementation this way avoids duplication between builders.
2115  if (!resultType) {
2116  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2117  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2118  }
2119  result.addAttributes(attrs);
2120  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2121  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2122  b.getDenseI64ArrayAttr(staticSizes),
2123  b.getDenseI64ArrayAttr(staticStrides));
2124 }
2125 
2126 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2127 /// result type.
2128 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2129  ArrayRef<OpFoldResult> offsets,
2130  ArrayRef<OpFoldResult> sizes,
2131  ArrayRef<OpFoldResult> strides,
2132  ArrayRef<NamedAttribute> attrs) {
2133  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2134 }
2135 
2136 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2137 /// a Range vector.
2138 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2139  ArrayRef<Range> ranges,
2140  ArrayRef<NamedAttribute> attrs) {
2141  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2142  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2143 }
2144 
2145 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2146 /// the type passed is nullptr, it is inferred.
2147 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2148  RankedTensorType resultType, Value source,
2149  ValueRange offsets, ValueRange sizes,
2150  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2151  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2152  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2153  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2154  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2155  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2156  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2157  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2158 }
2159 
2160 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2161 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2162  ValueRange offsets, ValueRange sizes,
2163  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2164  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2165 }
2166 
2167 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2168  Operation *op,
2169  RankedTensorType expectedType) {
2170  switch (result) {
2171  case SliceVerificationResult::Success:
2172  return success();
2173  case SliceVerificationResult::RankTooLarge:
2174  return op->emitError("expected rank to be smaller or equal to ")
2175  << "the other rank. ";
2176  case SliceVerificationResult::SizeMismatch:
2177  return op->emitError("expected type to be ")
2178  << expectedType << " or a rank-reduced version. (size mismatch) ";
2179  case SliceVerificationResult::ElemTypeMismatch:
2180  return op->emitError("expected element type to be ")
2181  << expectedType.getElementType();
2182  default:
2183  llvm_unreachable("unexpected extract_slice op verification result");
2184  }
2185 }
2186 
2187 /// Verifier for ExtractSliceOp.
2188 LogicalResult ExtractSliceOp::verify() {
2189  // Verify result type against inferred type.
2190  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2191  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2192  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2193  return produceSliceErrorMsg(result, *this, expectedType);
2194 }
2195 
2196 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2197  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2198 }
2199 
2200 FailureOr<Value>
2201 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2202  ArrayRef<int64_t> desiredShape) {
2203  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2204  assert(sourceTensorType && "not a ranked tensor type");
2205  auto sourceShape = sourceTensorType.getShape();
2206  if (sourceShape.equals(desiredShape))
2207  return value;
2208  auto maybeRankReductionMask =
2209  mlir::computeRankReductionMask(sourceShape, desiredShape);
2210  if (!maybeRankReductionMask)
2211  return failure();
2212  return createCanonicalRankReducingExtractSliceOp(
2213      b, loc, value,
2214  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2215 }
2216 
2217 LogicalResult ExtractSliceOp::reifyResultShapes(
2218  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2219  reifiedReturnShapes.resize(1);
2220  reifiedReturnShapes[0].reserve(getType().getRank());
2221  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2222  llvm::SmallBitVector droppedDims = getDroppedDims();
2223  for (const auto &size : enumerate(mixedSizes)) {
2224  if (droppedDims.test(size.index()))
2225  continue;
2226  reifiedReturnShapes[0].push_back(size.value());
2227  }
2228  return success();
2229 }
2230 
2231 namespace {
2232 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2233 /// This essentially pushes the tensor.cast past its consuming extract_slice
2234 /// when `canFoldIntoConsumerOp` is true.
2235 ///
2236 /// Example:
2237 /// ```
2238 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2239 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2240 /// tensor<3x4xf32>
2241 /// ```
2242 /// is rewritten into:
2243 /// ```
2244 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2245 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2246 /// ```
2247 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2248 public:
2250 
2251  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2252  PatternRewriter &rewriter) const override {
2253  // Any constant operand, just return to let the constant folder kick in.
2254  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2255  return matchPattern(operand, matchConstantIndex());
2256  }))
2257  return failure();
2258 
2259  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2260  if (!castOp)
2261  return failure();
2262 
2263  if (!canFoldIntoConsumerOp(castOp))
2264  return failure();
2265 
2266  // Create folded extract.
2267  Location loc = sliceOp.getLoc();
2268  Value newResult = rewriter.create<ExtractSliceOp>(
2269  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2270  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2271  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2272  if (newResult.getType() != sliceOp.getType())
2273  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2274  rewriter.replaceOp(sliceOp, newResult);
2275  return success();
2276  }
2277 };
2278 
2279 /// Slice elements from `values` into `outValues`. For each dimension, `counts`
2280 /// holds the number of `values` elements spanned by one step along that
2281 /// dimension. The output values can be used to construct a DenseElementsAttr.
2282 template <typename IterTy, typename ElemTy>
2283 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2284  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2285  ArrayRef<int64_t> strides,
2286  llvm::SmallVectorImpl<ElemTy> *outValues) {
2287  assert(offsets.size() == sizes.size());
2288  assert(offsets.size() == strides.size());
2289  if (offsets.empty())
2290  return;
2291 
2292  int64_t offset = offsets.front();
2293  int64_t size = sizes.front();
2294  int64_t stride = strides.front();
2295  if (offsets.size() == 1) {
2296  for (int64_t i = 0; i < size; ++i, offset += stride)
2297  outValues->push_back(*(values + offset));
2298 
2299  return;
2300  }
2301 
2302  for (int64_t i = 0; i < size; ++i, offset += stride) {
2303  auto begin = values + offset * counts.front();
2304  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2305  offsets.drop_front(), sizes.drop_front(),
2306  strides.drop_front(), outValues);
2307  }
2308 }
2309 
2310 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2311 /// folded operation may materialize additional constant data, so users can
2312 /// decide when the fold is applied via the control function.
2313 class ConstantOpExtractSliceFolder final
2314  : public OpRewritePattern<ExtractSliceOp> {
2315 public:
2317 
2318  ConstantOpExtractSliceFolder(MLIRContext *context,
2319                               ControlConstantExtractSliceFusionFn controlFn)
2320      : OpRewritePattern<ExtractSliceOp>(context),
2321  controlFn(std::move(controlFn)) {}
2322 
2323  LogicalResult matchAndRewrite(ExtractSliceOp op,
2324  PatternRewriter &rewriter) const override {
2325  DenseElementsAttr attr;
2326  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2327  return failure();
2328 
2329  // A constant splat is handled by fold().
2330  if (attr.isSplat())
2331  return failure();
2332 
2333  // Dynamic result shape is not supported.
2334  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2335  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2336  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2337  return failure();
2338 
2339  // Customized control over the folding.
2340  if (!controlFn(op))
2341  return failure();
2342 
2343  int64_t count = sourceType.getNumElements();
2344  if (count == 0)
2345  return failure();
2346 
2347  // Check if there are any dynamic parts, which are not supported.
2348  auto offsets = op.getStaticOffsets();
2349  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2350  return failure();
2351  auto sizes = op.getStaticSizes();
2352  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2353  return failure();
2354  auto strides = op.getStaticStrides();
2355  if (llvm::is_contained(strides, ShapedType::kDynamic))
2356  return failure();
2357 
2358  // Compute the stride for each dimension.
2359  SmallVector<int64_t> counts;
2360  ArrayRef<int64_t> shape = sourceType.getShape();
2361  counts.reserve(shape.size());
2362  for (int64_t v : shape) {
2363  count = count / v;
2364  counts.push_back(count);
2365  }
2366 
2367  // New attribute constructed by the sliced values.
2368  DenseElementsAttr newAttr;
2369 
2370  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2371  SmallVector<APInt> outValues;
2372  outValues.reserve(sourceType.getNumElements());
2373  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2374  elems.begin(), counts, offsets, sizes, strides, &outValues);
2375  newAttr = DenseElementsAttr::get(resultType, outValues);
2376  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2377  SmallVector<APFloat> outValues;
2378  outValues.reserve(sourceType.getNumElements());
2379  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2380  elems.begin(), counts, offsets, sizes, strides, &outValues);
2381  newAttr = DenseElementsAttr::get(resultType, outValues);
2382  }
2383 
2384  if (newAttr) {
2385  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2386  return success();
2387  }
2388 
2389  return failure();
2390  }
2391 
2392 private:
2393  /// Additionally controls whether the fold is applied; users can plug in
2394  /// their own heuristics through this function.
2395  ControlConstantExtractSliceFusionFn controlFn;
2396 };
2397 
2398 } // namespace
2399 
2400 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2401     RewritePatternSet &patterns,
2402  const ControlConstantExtractSliceFusionFn &controlFn) {
2403  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2404 }
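/// For illustration, a hypothetical way a client might register this fold with
/// a size-based heuristic; the 64-element threshold and the lambda shape are
/// made up, assuming the control function takes the candidate ExtractSliceOp:
/// ```cpp
/// RewritePatternSet patterns(ctx);
/// tensor::populateFoldConstantExtractSlicePatterns(
///     patterns, [](ExtractSliceOp op) {
///       // Only fold when the sliced constant stays small.
///       return op.getType().getNumElements() <= 64;
///     });
/// ```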
2405 
2406 /// Return the canonical type of the result of an extract_slice op.
2407 struct SliceReturnTypeCanonicalizer {
2408  RankedTensorType operator()(ExtractSliceOp op,
2409  ArrayRef<OpFoldResult> mixedOffsets,
2410  ArrayRef<OpFoldResult> mixedSizes,
2411  ArrayRef<OpFoldResult> mixedStrides) {
2412  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2413  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2414  mixedStrides);
2415  }
2416 };
2417 
2418 /// A canonicalizer wrapper to replace ExtractSliceOps.
2419 struct SliceCanonicalizer {
2420  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2421  ExtractSliceOp newOp) {
2422  Value replacement = newOp.getResult();
2423  if (replacement.getType() != op.getType())
2424  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2425  replacement);
2426  rewriter.replaceOp(op, replacement);
2427  }
2428 };
2429 
2430 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2431  MLIRContext *context) {
2432  results.add<
2433      OpWithOffsetSizesAndStridesConstantArgumentFolder<
2434          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2435      ExtractSliceOpCastFolder>(context);
2436 }
2437 
2438 //
2439 static LogicalResult
2440 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2441  ShapedType shapedType) {
2442  OpBuilder b(op.getContext());
2443  for (OpFoldResult ofr : op.getMixedOffsets())
2444  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2445  return failure();
2446  // Rank-reducing noops only need to inspect the leading dimensions:
2447  // llvm::zip is appropriate.
2448  auto shape = shapedType.getShape();
2449  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2450  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2451  return failure();
2452  for (OpFoldResult ofr : op.getMixedStrides())
2453  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2454  return failure();
2455  return success();
2456 }
2457 
2458 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2459 /// slice, we can return the InsertSliceOp's source directly.
2460 // TODO: This only checks the immediate producer; extend to go up the
2461 // insert/extract chain if the slices are disjoint.
2462 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2463  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2464 
2465  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2466  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2467  insertOp.isSameAs(extractOp, isSame))
2468  return insertOp.getSource();
2469 
2470  return {};
2471 }
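/// For illustration, a hypothetical pair folded by the logic above (names are
/// made up): the extract_slice reads back exactly the slice that was just
/// inserted, so it folds to %src.
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 1] [2, 4] [1, 1]
///     : tensor<2x4xf32> into tensor<8x8xf32>
/// %1 = tensor.extract_slice %0[0, 1] [2, 4] [1, 1]
///     : tensor<8x8xf32> to tensor<2x4xf32>
/// ```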
2472 
2473 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2474  if (OpFoldResult reshapedSource = reshapeConstantSource(
2475  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2476  getResult().getType()))
2477  return reshapedSource;
2478  if (getSourceType() == getType() &&
2479      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2480  return this->getSource();
2481  if (Value slice = foldExtractAfterInsertSlice(*this))
2482  return slice;
2483 
2484  return OpFoldResult();
2485 }
2486 
2487 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2488     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2489  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2490  unsigned rank = rankedTensorType.getRank();
2491  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2492  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2493  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2494  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2495  offsets, sizes, strides);
2496 }
2497 
2498 //===----------------------------------------------------------------------===//
2499 // InsertSliceOp
2500 //===----------------------------------------------------------------------===//
2501 
2502 void InsertSliceOp::getAsmResultNames(
2503  function_ref<void(Value, StringRef)> setNameFn) {
2504  setNameFn(getResult(), "inserted_slice");
2505 }
2506 
2507 // Build an InsertSliceOp with mixed static and dynamic entries.
2508 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2509  Value dest, ArrayRef<OpFoldResult> offsets,
2510  ArrayRef<OpFoldResult> sizes,
2511  ArrayRef<OpFoldResult> strides,
2512  ArrayRef<NamedAttribute> attrs) {
2513  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2514  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2515  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2516  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2517  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2518  result.addAttributes(attrs);
2519  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2520  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2521  b.getDenseI64ArrayAttr(staticSizes),
2522  b.getDenseI64ArrayAttr(staticStrides));
2523 }
2524 
2525 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2526 /// Range vector.
2527 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2528  Value dest, ArrayRef<Range> ranges,
2529  ArrayRef<NamedAttribute> attrs) {
2530  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2531  build(b, result, source, dest, offsets, sizes, strides, attrs);
2532 }
2533 
2534 // Build an InsertSliceOp with dynamic entries.
2535 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2536  Value dest, ValueRange offsets, ValueRange sizes,
2537  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2538  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2539  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2540  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2541  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2542  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2543  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2544  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2545 }
2546 
2547 /// Rank-reducing type verification for both InsertSliceOp and
2548 /// ParallelInsertSliceOp.
2549 static SliceVerificationResult verifyInsertSliceOp(
2550     RankedTensorType srcType, RankedTensorType dstType,
2551  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2552  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2553  // insert_slice is the inverse of extract_slice, use the same type
2554  // inference.
2555  RankedTensorType expected = ExtractSliceOp::inferResultType(
2556  dstType, staticOffsets, staticSizes, staticStrides);
2557  if (expectedType)
2558  *expectedType = expected;
2559  return isRankReducedType(expected, srcType);
2560 }
2561 
2562 /// Verifier for InsertSliceOp.
2563 LogicalResult InsertSliceOp::verify() {
2564  RankedTensorType expectedType;
2565  SliceVerificationResult result =
2566  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2567  getStaticSizes(), getStaticStrides(), &expectedType);
2568  return produceSliceErrorMsg(result, *this, expectedType);
2569 }
2570 
2571 /// If we have two consecutive InsertSliceOps writing to the same slice, we
2572 /// can replace the second InsertSliceOp's destination with the first one's.
2573 ///
2574 /// Example:
2575 ///
2576 /// ```mlir
2577 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2578 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2579 /// ```
2580 ///
2581 /// folds into:
2582 ///
2583 /// ```mlir
2584 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2585 /// ```
2586 ///
2587 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2588 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2589  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2590 
2591  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2592  if (!prevInsertOp ||
2593  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2594  !prevInsertOp.isSameAs(insertOp, isSame))
2595  return failure();
2596 
2597  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2598  return success();
2599 }
2600 
2601 /// Folds round-trip extract/insert slice op pairs.
2602 /// Example:
2603 /// ```mlir
2604 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2605 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2606 /// ```
2607 /// can be folded into %val.
2608 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2609  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2610 
2611  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2612  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2613  !extractOp.isSameAs(insertOp, isSame))
2614  return nullptr;
2615 
2616  return extractOp.getSource();
2617 }
2618 
2619 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2620  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2621  getSourceType() == getType() &&
2622      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2623  return this->getSource();
2624  if (succeeded(foldInsertAfterInsertSlice(*this)))
2625  return getResult();
2626  if (auto result = foldInsertAfterExtractSlice(*this))
2627  return result;
2628  if (llvm::any_of(getMixedSizes(),
2629  [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2630  return getDest();
2631  return OpFoldResult();
2632 }
2633 
2634 LogicalResult InsertSliceOp::reifyResultShapes(
2635  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2636  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2637  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2638  return success();
2639 }
2640 
2641 namespace {
2642 /// Pattern to rewrite an insert_slice op with constant arguments.
2643 ///
2644 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2645 template <typename InsertOpTy>
2646 class InsertSliceOpConstantArgumentFolder final
2647  : public OpRewritePattern<InsertOpTy> {
2648 public:
2650 
2651  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2652  PatternRewriter &rewriter) const override {
2653  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2654  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2655  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2656 
2657  // If no constant operands were folded, just return.
2658  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2659  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2660  failed(foldDynamicStrideList(mixedStrides)))
2661  return failure();
2662 
2663  // Create the new op in canonical form.
2664  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2665  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2666  mixedOffsets, mixedSizes, mixedStrides);
2667  Value toInsert = insertSliceOp.getSource();
2668  if (sourceType != insertSliceOp.getSourceType()) {
2669  OpBuilder::InsertionGuard g(rewriter);
2670  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2671  // is that the insertion point is just before the ParallelCombiningOp in
2672  // the parallel case.
2673  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2674  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2675  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2676  sourceType, toInsert);
2677  }
2678  rewriter.replaceOpWithNewOp<InsertOpTy>(
2679  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2680  mixedSizes, mixedStrides);
2681  return success();
2682  }
2683 };
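/// For illustration, a hypothetical canonicalization performed by the pattern
/// above (names are made up): constant index operands are folded into the
/// static offset/size/stride attributes.
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %r = tensor.insert_slice %src into %dest[%c0, %c0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// // becomes
/// %r = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// ```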
2684 
2685 /// Fold tensor_casts with insert_slice operations. If the source or
2686 /// destination tensor is a tensor_cast that removes static type information,
2687 /// the cast is folded into the insert_slice operation. E.g.:
2688 ///
2689 /// ```mlir
2690 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2691 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2692 /// ```
2693 ///
2694 /// folds into:
2695 ///
2696 /// ```mlir
2697 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2698 /// ```
2699 ///
2700 /// Note: When folding a cast on the destination tensor, the result of the
2701 /// insert_slice operation is cast back to ensure that the type of the result
2702 /// does not change.
2703 ///
2704 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2705 template <typename InsertOpTy>
2706 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2708 
2709  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2710  PatternRewriter &rewriter) const override {
2711  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2712  return matchPattern(operand, matchConstantIndex());
2713  }))
2714  return failure();
2715 
2716  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2717  auto castOp = v.getDefiningOp<tensor::CastOp>();
2718  if (!castOp || !canFoldIntoConsumerOp(castOp))
2719  return std::nullopt;
2720  return castOp.getSource();
2721  };
2722  std::optional<Value> sourceCastSource =
2723  getSourceOfCastOp(insertSliceOp.getSource());
2724  std::optional<Value> destCastSource =
2725  getSourceOfCastOp(insertSliceOp.getDest());
2726  if (!sourceCastSource && !destCastSource)
2727  return failure();
2728 
2729  auto src =
2730  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2731  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2732  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2733  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2734  if (!srcType || !dstType)
2735  return failure();
2736 
2737  // The tensor.cast source could have additional static information not seen
2738  // in the insert slice op static sizes, so we ignore dynamic dims when
2739  // computing the rank reduction mask.
2740  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2741  auto rankReductionMask = computeRankReductionMask(
2742  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2743  if (!rankReductionMask.has_value())
2744  return failure();
2745  // Replace dimensions in the insert slice op with corresponding static dims
2746  // from the cast source type. If the insert slice sizes have static dims
2747  // that are not static in the tensor.cast source (i.e., when the cast op
2748  // casts a dynamic dim to static), the dim should not be replaced, and the
2749  // pattern will fail later in `verifyInsertSliceOp`.
2750  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2751  int64_t rankReducedIdx = 0;
2752  for (auto [idx, size] : enumerate(staticSizes)) {
2753  if (!rankReductionMask.value().contains(idx) &&
2754  !srcType.isDynamicDim(rankReducedIdx)) {
2755  mixedSizes[idx] = getAsIndexOpFoldResult(
2756  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2757  size = srcType.getDimSize(rankReducedIdx++);
2758  }
2759  }
2760  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2761  staticSizes, insertSliceOp.getStaticStrides()) !=
2762      SliceVerificationResult::Success)
2763  return failure();
2764 
2765  Operation *replacement = rewriter.create<InsertOpTy>(
2766  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2767  mixedSizes, insertSliceOp.getMixedStrides());
2768 
2769  // In the parallel case there is no result and so nothing to cast.
2770  bool isParallelInsert =
2771  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2772  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2773  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2774  insertSliceOp.getDestType(),
2775  replacement->getResult(0));
2776  }
2777  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2778  return success();
2779  }
2780 };
2781 
2782 /// If additional static type information can be deduced from an insert_slice's
2783 /// size operands, insert an explicit cast of the op's source operand. This
2784 /// enables other canonicalization patterns that are matching for tensor_cast
2785 /// ops such as `ForOpTensorCastFolder` in SCF.
2786 ///
2787 /// Example:
2788 ///
2789 /// ```mlir
2790 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2791 /// : tensor<?x?xf32> into ...
2792 /// ```
2793 ///
2794 /// folds into:
2795 ///
2796 /// ```mlir
2797 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2798 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2799 /// : tensor<64x64xf32> into ...
2800 /// ```
2801 ///
2802 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2803 template <typename InsertOpTy>
2804 struct InsertSliceOpSourceCastInserter final
2805  : public OpRewritePattern<InsertOpTy> {
2807 
2808  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2809  PatternRewriter &rewriter) const override {
2810  RankedTensorType srcType = insertSliceOp.getSourceType();
2811  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2812  return failure();
2813  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
2814  srcType.getShape().end());
2815  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2816  if (std::optional<int64_t> constInt =
2817  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2818  // Bail on invalid IR.
2819  if (*constInt < 0)
2820  return failure();
2821  newSrcShape[i] = *constInt;
2822  }
2823  }
2824  if (!hasValidSizesOffsets(newSrcShape))
2825  return failure();
2826 
2827  RankedTensorType newSrcType = RankedTensorType::get(
2828  newSrcShape, srcType.getElementType(), srcType.getEncoding());
2829  if (srcType == newSrcType ||
2830  !preservesStaticInformation(srcType, newSrcType) ||
2831  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2832  return failure();
2833 
2834  // newSrcType is:
2835  // 1) Different from srcType.
2836  // 2) "More static" than srcType.
2837  // 3) Cast-compatible with srcType.
2838  // Insert the cast.
2839  OpBuilder::InsertionGuard g(rewriter);
2840  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2841  // that the insertion point is just before the ParallelCombiningOp in the
2842  // parallel case.
2843  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2844  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2845  Value cast = rewriter.create<tensor::CastOp>(
2846  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2847  rewriter.replaceOpWithNewOp<InsertOpTy>(
2848  insertSliceOp, cast, insertSliceOp.getDest(),
2849  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2850  insertSliceOp.getMixedStrides());
2851  return success();
2852  }
2853 };
2854 } // namespace
2855 
2856 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2857  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2858 }
2859 
2860 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2861  MLIRContext *context) {
2862  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2863  InsertSliceOpCastFolder<InsertSliceOp>,
2864  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2865 }
2866 
2867 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2868                                                              Location loc,
2869  Value tensor,
2870  Value dest) {
2871  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
2872  unsigned rank = rankedTensorType.getRank();
2873  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2874  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2875  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2876  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2877  sizes, strides);
2878 }
2879 
2880 //===----------------------------------------------------------------------===//
2881 // PadOp
2882 //===----------------------------------------------------------------------===//
2883 
2884 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2885  setNameFn(getResult(), "padded");
2886 }
2887 
2888 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2889 // supports optional types.
2890 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2891  Type typeToInfer, Type typeToInferFrom) {}
2892 
2893 ParseResult
2894 parseInferType(OpAsmParser &parser,
2895                std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2896  Type &typeToInfer, Type typeToInferFrom) {
2897  if (optOperand)
2898  typeToInfer = typeToInferFrom;
2899  return success();
2900 }
2901 
2902 LogicalResult PadOp::verify() {
2903  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
2904  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
2905  auto expectedType =
2906  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2907  if (!expectedType) {
2908  return emitError("failed to infer expectedType from sourceType ")
2909  << sourceType << ", specified resultType is " << resultType;
2910  }
2911  if (resultType.getRank() != expectedType.getRank()) {
2912  return emitError("specified type ")
2913  << resultType << " does not match the inferred type "
2914  << expectedType;
2915  }
2916  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2917  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2918  continue;
2919  if (expectedType.isDynamicDim(i))
2920  continue;
2921  return emitError("specified type ")
2922  << resultType << " does not match the inferred type "
2923  << expectedType;
2924  }
2925 
2926  return success();
2927 }
2928 
2929 LogicalResult PadOp::verifyRegions() {
2930  auto &region = getRegion();
2931  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2932  Block &block = region.front();
2933  if (block.getNumArguments() != rank)
2934  return emitError("expected the block to have ") << rank << " arguments";
2935 
2936  // Note: the number and type of yield values are checked in the YieldOp.
2937  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2938  if (!en.value().isIndex())
2939  return emitOpError("expected block argument ")
2940  << (en.index() + 1) << " to be an index";
2941  }
2942 
2943  // Ensure that the region yields an element of the right type.
2944  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2945  if (yieldOp.getValue().getType() !=
2946  llvm::cast<ShapedType>(getType()).getElementType())
2947  return emitOpError("expected yield type to match shape element type");
2948 
2949  return success();
2950 }
2951 
2952 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2953  ArrayRef<int64_t> staticLow,
2954  ArrayRef<int64_t> staticHigh,
2955  ArrayRef<int64_t> resultShape) {
2956  unsigned rank = sourceType.getRank();
2957  if (staticLow.size() != rank)
2958  return RankedTensorType();
2959  if (staticHigh.size() != rank)
2960  return RankedTensorType();
2961  if (!resultShape.empty() && resultShape.size() != rank)
2962  return RankedTensorType();
2963 
2964  SmallVector<int64_t, 4> inferredShape;
2965  for (auto i : llvm::seq<unsigned>(0, rank)) {
2966  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2967  staticHigh[i] == ShapedType::kDynamic) {
2968  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2969  : resultShape[i]);
2970  } else {
2971  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2972  assert((resultShape.empty() || size == resultShape[i] ||
2973  resultShape[i] == ShapedType::kDynamic) &&
2974  "mismatch between inferred shape and result shape");
2975  inferredShape.push_back(size);
2976  }
2977  }
2978 
2979  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2980 }
2981 
2982 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2983  Value source, ArrayRef<int64_t> staticLow,
2984  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
2985  bool nofold, ArrayRef<NamedAttribute> attrs) {
2986  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2987  if (!resultType)
2988  resultType = inferResultType(sourceType, staticLow, staticHigh);
2989  result.addAttributes(attrs);
2990  build(b, result, resultType, source, low, high,
2991  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2992  nofold ? b.getUnitAttr() : UnitAttr());
2993 }
2994 
2995 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2996  Value source, ValueRange low, ValueRange high, bool nofold,
2997  ArrayRef<NamedAttribute> attrs) {
2998  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2999  unsigned rank = sourceType.getRank();
3000  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3001  build(b, result, resultType, source, staticVector, staticVector, low, high,
3002  nofold, attrs);
3003 }
3004 
3005 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3006  Value source, ArrayRef<OpFoldResult> low,
3007  ArrayRef<OpFoldResult> high, bool nofold,
3008  ArrayRef<NamedAttribute> attrs) {
3009  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3010  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3011  SmallVector<int64_t, 4> staticLow, staticHigh;
3012  // staticLow and staticHigh carry the full padding configuration. Each entry
3013  // in `low`/`high` grows staticLow and staticHigh by one value. If the entry
3014  // is dynamic (i.e., not a constant), dynamicLow and dynamicHigh grow by one
3015  // value as well.
3016  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3017  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3018  if (!resultType) {
3019  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3020  }
3021  assert(llvm::isa<RankedTensorType>(resultType));
3022  result.addAttributes(attrs);
3023  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3024  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3025  nofold ? b.getUnitAttr() : UnitAttr());
3026 }
3027 
3028 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3029  Value source, ArrayRef<OpFoldResult> low,
3030  ArrayRef<OpFoldResult> high, Value constantPadValue,
3031  bool nofold, ArrayRef<NamedAttribute> attrs) {
3032  build(b, result, resultType, source, low, high, nofold, attrs);
3033 
3034  // Add a region and a block to yield the pad value.
3035  Region *region = result.regions[0].get();
3036  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3037  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3038  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3039 
3040  // `builder.createBlock` changes the insertion point within the block. Create
3041  // a guard to restore the builder's insertion point when the guard is destroyed.
3042  OpBuilder::InsertionGuard guard(b);
3043  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3044  b.create<tensor::YieldOp>(result.location, constantPadValue);
3045 }
3046 
3047 llvm::SmallBitVector PadOp::getPaddedDims() {
3048  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3049  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3050  for (const auto &en : enumerate(paddingWidths))
3051  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3052  paddedDims.set(en.index());
3053  };
3054  extractPaddedDims(getMixedLowPad());
3055  extractPaddedDims(getMixedHighPad());
3056  return paddedDims;
3057 }
3058 
3059 namespace {
3060 // Folds tensor.pad when the padding is all static zeros and the nofold
3061 // attribute doesn't request otherwise.
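//
// A minimal illustrative sketch (not from the source; %src and %cst are
// hypothetical values). All-zero padding without the nofold attribute only
// changes the static type, so
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] {
//     ^bb0(%i: index, %j: index):
//       tensor.yield %cst : f32
//   } : tensor<?x?xf32> to tensor<4x4xf32>
//
// is rewritten to %0 = tensor.cast %src : tensor<?x?xf32> to tensor<4x4xf32>.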
3062 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3063  using OpRewritePattern<PadOp>::OpRewritePattern;
3064 
3065  LogicalResult matchAndRewrite(PadOp padTensorOp,
3066  PatternRewriter &rewriter) const override {
3067  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3068  return failure();
3069  if (padTensorOp.getNofold())
3070  return failure();
3071  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3072  padTensorOp, padTensorOp.getResult().getType(),
3073  padTensorOp.getSource());
3074  return success();
3075  }
3076 };
3077 
3078 // Fold CastOp into PadOp when adding static information.
3079 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3081 
3082  LogicalResult matchAndRewrite(PadOp padTensorOp,
3083  PatternRewriter &rewriter) const override {
3084  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3085  if (!tensor::canFoldIntoConsumerOp(castOp))
3086  return failure();
3087 
3088  auto newResultType = PadOp::inferResultType(
3089  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3090  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3091  padTensorOp.getResultType().getShape());
3092 
3093  if (newResultType == padTensorOp.getResultType()) {
3094  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3095  padTensorOp.getSourceMutable().assign(castOp.getSource());
3096  });
3097  } else {
3098  auto newOp = rewriter.create<PadOp>(
3099  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3100  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3101  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3102  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3103  IRMapping mapper;
3104  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3105 
3106  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3107  padTensorOp, padTensorOp.getResultType(), newOp);
3108  }
3109  return success();
3110  }
3111 };
3112 
3113 // Fold CastOp using the result of PadOp back into the latter if it adds
3114 // static information.
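//
// An illustrative sketch (hypothetical IR; %src and %h are placeholders): when
// the cast is the only user of the pad result,
//
//   %0 = tensor.pad %src low[0] high[%h] { ... } : tensor<8xf32> to tensor<?xf32>
//   %1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
//
// is rewritten so that the tensor.pad directly produces tensor<16xf32> and the
// cast is removed.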
3115 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3116  using OpRewritePattern<PadOp>::OpRewritePattern;
3117 
3118  LogicalResult matchAndRewrite(PadOp padTensorOp,
3119  PatternRewriter &rewriter) const override {
3120  if (!padTensorOp.getResult().hasOneUse())
3121  return failure();
3122  auto tensorCastOp =
3123  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3124  if (!tensorCastOp)
3125  return failure();
3126  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3127  tensorCastOp.getDest().getType()))
3128  return failure();
3129 
3130  auto replacementOp = rewriter.create<PadOp>(
3131  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3132  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3133  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3134  padTensorOp.getHigh(), padTensorOp.getNofold(),
3135  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3136  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3137 
3138  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3139  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3140  return success();
3141  }
3142 };
3143 
3144 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3145 /// different dimensions. The pattern applies if the following preconditions
3146 /// hold:
3147 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3148 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3149 /// 3) the tensor::PadOps perform only high-padding,
3150 /// 4) the tensor::PadOps have the same constant padding value,
3151 /// 5) the tensor::PadOps do not have common padding dimensions,
3152 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3153 /// zero-offset for every dimension.
3154 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
3155 ///    padded source dimensions.
3156 ///
3157 ///
3158 /// Example:
3159 ///
3160 /// ```mlir
3161 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3162 /// : tensor<64x64xf32> to tensor<?x64xf32>
3163 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3164 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3165 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3166 /// : tensor<8x64xf32> to tensor<8x?xf32>
3167 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3168 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3169 /// ```
3170 ///
3171 /// folds into:
3172 ///
3173 /// ```mlir
3174 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3175 /// : tensor<64x64xf32> to tensor<?x?xf32>
3176 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3177 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3178 /// ```
3179 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3180  using OpRewritePattern<PadOp>::OpRewritePattern;
3181 
3182  LogicalResult matchAndRewrite(PadOp padOp,
3183  PatternRewriter &rewriter) const override {
3184  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3185  if (!innerSliceOp)
3186  return failure();
3187  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3188  if (!outerPadOp || outerPadOp.getNofold())
3189  return failure();
3190  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3191  if (!outerSliceOp)
3192  return failure();
3193 
3194  // 1) Fail if the chain is rank-reducing.
3195  int64_t rank = padOp.getSourceType().getRank();
3196  if (outerSliceOp.getSourceType().getRank() != rank) {
3197  return rewriter.notifyMatchFailure(padOp,
3198  "cannot fold rank-reducing chain");
3199  }
3200 
3201  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3202  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3203  return rewriter.notifyMatchFailure(
3204  padOp, "cannot fold non-unit stride ExtractSliceOps");
3205  }
3206 
3207  // 3) Fail if the tensor::PadOps have non-zero low padding.
3208  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3209  return rewriter.notifyMatchFailure(padOp,
3210  "cannot fold PadOps with low padding");
3211  }
3212 
3213  // 4) Fail if the tensor::PadOps padding values do not match.
3214  Attribute innerAttr, outerAttr;
3215  Value innerValue = padOp.getConstantPaddingValue();
3216  Value outerValue = outerPadOp.getConstantPaddingValue();
3217  if (!innerValue || !outerValue ||
3218  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3219  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3220  innerAttr != outerAttr) {
3221  return rewriter.notifyMatchFailure(
3222  padOp, "cannot fold PadOps with different padding values");
3223  }
3224 
3225  // 5) Fail if a dimension is padded by both tensor::PadOps.
3226  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3227  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3228  if (innerDims.anyCommon(outerDims)) {
3229  return rewriter.notifyMatchFailure(
3230  padOp, "cannot fold PadOps with common padding dimensions");
3231  }
3232 
3233  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3234  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3235  // for every dimension, and use the offset of the other pair. Fail if no
3236  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3237  // exists.
3238  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3239  for (auto en : enumerate(newOffsets)) {
3240  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3241  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3242  if (!innerDims.test(en.index()) &&
3243  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3244  en.value() = outerOffset;
3245  continue;
3246  }
3247  if (!outerDims.test(en.index()) &&
3248  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3249  en.value() = innerOffset;
3250  continue;
3251  }
3252  return rewriter.notifyMatchFailure(
3253  padOp, "cannot find zero-offset and zero-padding pair");
3254  }
3255 
3256  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3257  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3258  // outer tensor::PadOp and fail if the size of the inner
3259  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3260  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3261  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3262  for (auto en : enumerate(newSizes)) {
3263  if (!outerDims.test(en.index()))
3264  continue;
3265  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3266  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3267  assert(!ShapedType::isDynamic(sourceSize) &&
3268  "expected padded dimension to have a static size");
3269  if (getConstantIntValue(sliceSize) != sourceSize) {
3270  return rewriter.notifyMatchFailure(
3271  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3272  "match the size of the outer padding");
3273  }
3274  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3275  }
3276 
3277  // Combine the high paddings of the two tensor::PadOps.
3278  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3279  for (auto en : enumerate(newHighPad)) {
3280  if (innerDims.test(en.index()))
3281  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3282  if (outerDims.test(en.index()))
3283  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3284  }
3285 
3286  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3287  // the two paddings in one step.
3288  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3289  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3290  innerSliceOp.getMixedStrides());
3291  auto newPadOp = rewriter.create<PadOp>(
3292  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3293  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3294  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3295  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3296  newPadOp.getRegion().begin());
3297  rewriter.replaceOp(padOp, newPadOp.getResult());
3298  return success();
3299  }
3300 };
3301 
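// Folds constant low/high padding operands of tensor.pad into the op's static
// padding attributes, infers a more static result type from them, and casts the
// new result back to the original result type.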
3302 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3303  using OpRewritePattern<PadOp>::OpRewritePattern;
3304 
3305  LogicalResult matchAndRewrite(PadOp padTensorOp,
3306  PatternRewriter &rewriter) const override {
3307  Value input = padTensorOp.getSource();
3308  if (!llvm::isa<RankedTensorType>(input.getType()))
3309  return failure();
3310  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3311  auto inputRank = inputDims.size();
3312 
3313  auto oldResultType =
3314  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3315  if (!oldResultType)
3316  return failure();
3317 
3318  auto outputDims = oldResultType.getShape();
3319 
3320  // Extract the static info from the high and low operands.
3321  SmallVector<int64_t> constOperandsLow;
3322  SmallVector<Value> newLows;
3323  for (auto operand : padTensorOp.getLow()) {
3324  APSInt intOp;
3325  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3326  constOperandsLow.push_back(ShapedType::kDynamic);
3327  newLows.push_back(operand);
3328  continue;
3329  }
3330  constOperandsLow.push_back(intOp.getExtValue());
3331  }
3332  SmallVector<int64_t> constOperandsHigh;
3333  SmallVector<Value> newHighs;
3334  for (auto operand : padTensorOp.getHigh()) {
3335  APSInt intOp;
3336  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3337  constOperandsHigh.push_back(ShapedType::kDynamic);
3338  newHighs.push_back(operand);
3339  continue;
3340  }
3341  constOperandsHigh.push_back(intOp.getExtValue());
3342  }
3343 
3344  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3345  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3346 
3347  // Verify the op is well-formed.
3348  if (inputDims.size() != outputDims.size() ||
3349  inputDims.size() != constLow.size() ||
3350  inputDims.size() != constHigh.size())
3351  return failure();
3352 
3353  auto lowCount = 0;
3354  auto highCount = 0;
3355  for (size_t i = 0; i < inputRank; i++) {
3356  if (constLow[i] == ShapedType::kDynamic)
3357  constLow[i] = constOperandsLow[lowCount++];
3358  if (constHigh[i] == ShapedType::kDynamic)
3359  constHigh[i] = constOperandsHigh[highCount++];
3360  }
3361 
3362  auto staticLow = ArrayRef<int64_t>(constLow);
3363  auto staticHigh = ArrayRef<int64_t>(constHigh);
3364 
3365  // Calculate the output sizes with the static information.
3366  SmallVector<int64_t> newOutDims;
3367  for (size_t i = 0; i < inputRank; i++) {
3368  if (outputDims[i] == ShapedType::kDynamic) {
3369  newOutDims.push_back(
3370  (staticLow[i] == ShapedType::kDynamic ||
3371  staticHigh[i] == ShapedType::kDynamic ||
3372  inputDims[i] == ShapedType::kDynamic
3373  ? ShapedType::kDynamic
3374  : inputDims[i] + staticLow[i] + staticHigh[i]));
3375  } else {
3376  newOutDims.push_back(outputDims[i]);
3377  }
3378  }
3379 
3380  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3381  llvm::all_of(newOutDims,
3382  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3383  return failure();
3384 
3385  // Rewrite the op using the new static type.
3386  auto newResultType = RankedTensorType::get(
3387  newOutDims, padTensorOp.getType().getElementType());
3388  auto newOp = rewriter.create<PadOp>(
3389  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3390  newLows, newHighs, padTensorOp.getNofold(),
3391  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3392 
3393  IRMapping mapper;
3394  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3395  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3396  newOp);
3397 
3398  return success();
3399  }
3400 };
3401 
3402 } // namespace
3403 
3404 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3405  MLIRContext *context) {
3406  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3407  FoldOrthogonalPaddings, FoldStaticPadding>(context);
3408 }
3409 
3410 /// Return the padding value of the PadOp if it is constant. In this context,
3411 /// "constant" means an actual constant or "defined outside of the block".
3412 ///
3413 /// Values are considered constant in three cases:
3414 /// - A ConstantLike value.
3415 /// - A basic block argument from a different block.
3416 /// - A value defined outside of the block.
3417 ///
3418 /// If the padding value is not constant, an empty Value is returned.
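///
/// An illustrative sketch (hypothetical IR, not from the source): in
///
/// ```mlir
///   %cst = arith.constant 0.0 : f32
///   %0 = tensor.pad %src low[1] high[1] {
///     ^bb0(%i: index):
///       tensor.yield %cst : f32
///   } : tensor<8xf32> to tensor<10xf32>
/// ```
///
/// the returned padding value is %cst. If the yielded value were computed
/// inside the pad region, an empty Value would be returned instead.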
3419 Value PadOp::getConstantPaddingValue() {
3420  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3421  if (!yieldOp)
3422  return {};
3423  Value padValue = yieldOp.getValue();
3424  // Check if yield value is a constant.
3425  if (matchPattern(padValue, m_Constant()))
3426  return padValue;
3427  // Check if yield value is defined inside the PadOp block.
3428  if (padValue.getParentBlock() == &getRegion().front())
3429  return {};
3430  // Else: Yield value defined outside of the PadOp block.
3431  return padValue;
3432 }
3433 
3434 OpFoldResult PadOp::fold(FoldAdaptor) {
3435  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3436  !getNofold())
3437  return getSource();
3438  return {};
3439 }
3440 
3441 //===----------------------------------------------------------------------===//
3442 // ParallelInsertSliceOp
3443 //===----------------------------------------------------------------------===//
3444 
3445 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3446  ParallelCombiningOpInterface parallelCombiningParent =
3447  getParallelCombiningParent();
3448  for (const auto &it :
3449  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3450  Operation &nextOp = it.value();
3451  if (&nextOp == getOperation())
3452  return parallelCombiningParent.getParentResult(it.index());
3453  }
3454  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3455 }
3456 
3457 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3458 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3459  Value source, Value dest,
3460  ArrayRef<OpFoldResult> offsets,
3461  ArrayRef<OpFoldResult> sizes,
3462  ArrayRef<OpFoldResult> strides,
3463  ArrayRef<NamedAttribute> attrs) {
3464  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3465  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3466  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3467  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3468  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3469  result.addAttributes(attrs);
3470  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3471  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3472  b.getDenseI64ArrayAttr(staticSizes),
3473  b.getDenseI64ArrayAttr(staticStrides));
3474 }
3475 
3476 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3477 /// packed into a Range vector.
3478 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3479  Value source, Value dest,
3480  ArrayRef<Range> ranges,
3481  ArrayRef<NamedAttribute> attrs) {
3482  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3483  build(b, result, source, dest, offsets, sizes, strides, attrs);
3484 }
3485 
3486 // Build a ParallelInsertSliceOp with dynamic entries.
3487 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3488  Value source, Value dest, ValueRange offsets,
3489  ValueRange sizes, ValueRange strides,
3490  ArrayRef<NamedAttribute> attrs) {
3491  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3492  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3493  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3494  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3495  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3496  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3497  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3498 }
3499 
3500 LogicalResult ParallelInsertSliceOp::verify() {
3501  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3502  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3503  << *(getOperation()->getParentOp());
3504 
3505  RankedTensorType expectedType;
3506  SliceVerificationResult result =
3507  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3508  getStaticSizes(), getStaticStrides(), &expectedType);
3509  return produceSliceErrorMsg(result, *this, expectedType);
3510 }
3511 
3512 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3513  RewritePatternSet &results, MLIRContext *context) {
3514  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3515  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3516  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3517 }
3518 
3519 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3520  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3521 }
3522 
3523 //===----------------------------------------------------------------------===//
3524 // ScatterOp
3525 //===----------------------------------------------------------------------===//
3526 
3527 void ScatterOp::getAsmResultNames(
3528  function_ref<void(Value, StringRef)> setNameFn) {
3529  setNameFn(getResult(), "scatter");
3530 }
3531 
3532 LogicalResult ScatterOp::verify() {
3533  int64_t destRank = getDestType().getRank();
3534  ArrayRef<int64_t> scatterDims = getScatterDims();
3535  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3536  "scatter", "dest")))
3537  return failure();
3538 
3539  if (!getUnique())
3540  return emitOpError("requires 'unique' attribute to be set");
3541  // TODO: we could also check statically that there are fewer leading index
3542  // tensor dims than the dest dims. If this is not the case, the unique
3543  // attribute cannot be true.
3544 
3545  // Use the GatherOp::inferResultType on the `dest` type and verify the
3546  // expected type matches the source type.
3547  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3548  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3549  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3550  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3551  if (getSourceType() != expectedSourceType &&
3552  getSourceType() != expectedRankReducedSourceType) {
3553  return emitOpError("source type "
3554  "mismatch: "
3555  "expected ")
3556  << expectedSourceType << " or its rank-reduced variant "
3557  << expectedRankReducedSourceType << " (got: " << getSourceType()
3558  << ")";
3559  }
3560 
3561  return success();
3562 }
3563 
3564 //===----------------------------------------------------------------------===//
3565 // SplatOp
3566 //===----------------------------------------------------------------------===//
3567 
3568 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3569  Type aggregateType, ValueRange dynamicSizes) {
3570  build(builder, result, aggregateType, element, dynamicSizes);
3571 }
3572 
3573 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3574  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3575  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3576  build(builder, result, aggregateType, element, dynamicSizes);
3577 }
3578 
3579 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3580  ArrayRef<OpFoldResult> sizes) {
3581  SmallVector<int64_t> staticShape;
3582  SmallVector<Value> dynamicSizes;
3583  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3584  build(builder, result, element, staticShape, dynamicSizes);
3585 }
3586 
3587 void SplatOp::getAsmResultNames(
3588  function_ref<void(Value, StringRef)> setNameFn) {
3589  setNameFn(getResult(), "splat");
3590 }
3591 
3592 LogicalResult SplatOp::verify() {
3593  if (getType().getNumDynamicDims() !=
3594  static_cast<int64_t>(getDynamicSizes().size()))
3595  return emitOpError("incorrect number of dynamic sizes, has ")
3596  << getDynamicSizes().size() << ", expected "
3597  << getType().getNumDynamicDims();
3598  return success();
3599 }
3600 
3601 LogicalResult
3602 SplatOp::reifyResultShapes(OpBuilder &builder,
3603  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3604  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3605  unsigned ctr = 0;
3606  for (int64_t i = 0; i < getType().getRank(); ++i) {
3607  if (getType().isDynamicDim(i)) {
3608  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3609  } else {
3610  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3611  }
3612  }
3613  return success();
3614 }
3615 
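// An illustrative sketch of the fold below (hypothetical IR): a statically
// shaped splat of a constant scalar, e.g.
//
//   %c = arith.constant 1.0 : f32
//   %s = tensor.splat %c : tensor<4xf32>
//
// folds to the attribute dense<1.000000e+00> : tensor<4xf32>.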
3616 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3617  auto constOperand = adaptor.getInput();
3618  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3619  return {};
3620 
3621  // Do not fold if the splat is not statically shaped
3622  if (!getType().hasStaticShape())
3623  return {};
3624 
3625  // SplatElementsAttr::get treats a single value for the second arg as a
3626  // splat.
3627  return SplatElementsAttr::get(getType(), {constOperand});
3628 }
3629 
3630 //===----------------------------------------------------------------------===//
3631 // PackOp/UnPackOp Common
3632 //===----------------------------------------------------------------------===//
3633 
3634 template <typename OpTy>
3635 static LogicalResult
3636 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3637  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3638  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3639  "applies to only pack or unpack operations");
3640  int64_t destRank = op.getDestRank();
3641  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3642  reifiedReturnShapes[0] =
3643  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3644  return success();
3645 }
3646 
3647 template <typename OpTy>
3648 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3649  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3650  "applies to only pack or unpack operations");
3651  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3652  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3653  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3654  assert(tiles.size() == dimsToTile.size() &&
3655  "tiles must match indices of dimension to block");
3656  // bind the dimension `i` with the tile factor.
3657  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3658  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3659  return dimAndTileMapping;
3660 }
3661 
3662 template <typename OpTy>
3663 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3664  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3665  "applies to only pack or unpack operations");
3666  Builder builder(op);
3667  SmallVector<OpFoldResult> mixedInnerTiles;
3668  unsigned dynamicValIndex = 0;
3669  for (int64_t staticTile : op.getStaticInnerTiles()) {
3670  if (!ShapedType::isDynamic(staticTile))
3671  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3672  else
3673  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3674  }
3675  return mixedInnerTiles;
3676 }
3677 
3678 template <typename OpTy>
3679 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3680  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3681  "applies to only pack or unpack operations");
3682  SmallVector<Value> dynamicTiles;
3683  SmallVector<int64_t> staticTiles;
3684  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3685  return staticTiles;
3686 }
3687 
3688 /// Returns true if `dimsPos` is invalid. It is invalid when:
3689 /// a) It contains duplicates.
3690 /// b) At least one dimension is out of bounds (valid positions satisfy 0 <= `dimPos` < `rank`).
3691 /// c) The number of elements in `dimsPos` is greater than `rank`.
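///
/// For example (illustrative): with `rank` = 3, [0, 2] is a valid
/// specification, while [0, 0] (duplicate), [3] (out of bounds), and
/// [0, 1, 2, 0] (too many entries) are all invalid.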
3692 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3693  size_t rank) {
3694  size_t dimsPosSize = dimsPos.size();
3695  if (dimsPosSize > rank)
3696  return true;
3697  DenseSet<int64_t> uniqued;
3698  for (int64_t dim : dimsPos)
3699  uniqued.insert(dim);
3700  if (dimsPosSize != uniqued.size())
3701  return true;
3702  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3703  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3704  });
3705 }
3706 
3707 /// Returns true if each dimension of `sourceShape` is smaller than or equal to
3708 /// the corresponding dimension of `limitShape` (dynamic dimensions are in bound).
3709 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3710  ArrayRef<int64_t> limitShape) {
3711  assert(
3712  sourceShape.size() == limitShape.size() &&
3713  "expected source shape rank, and limit of the shape to have same rank");
3714  return llvm::all_of(
3715  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3716  int64_t sourceExtent = std::get<0>(it);
3717  int64_t limit = std::get<1>(it);
3718  return ShapedType::isDynamic(sourceExtent) ||
3719  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3720  });
3721 }
3722 
3723 template <typename OpTy>
3724 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3725  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3726  "applies to only pack or unpack operations");
3727  Operation *op = packOrUnPack.getOperation();
3728 
3729  // Return true if we have a zero-value tile.
3730  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3731  return llvm::any_of(
3732  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3733  };
3734 
3735  // Verify tiles. Do not allow zero tiles.
3736  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3737  if (hasZeros(mixedTiles))
3738  return op->emitError("invalid zero tile factor");
3739 
3740  // Verify inner_dims_pos and outer_dims_perm.
3741  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3742  ? packOrUnPack.getSourceType()
3743  : packOrUnPack.getDestType();
3744  size_t unpackedRank = unpackedType.getRank();
3745  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3746  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3747  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3748  return op->emitError("invalid inner_dims_pos vector");
3749  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3750  return op->emitError("invalid outer_dims_perm vector");
3751  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3752  return op->emitError("outer_dims_perm must be a permutation or empty");
3753 
3754  // Tiling factors must be less than or equal to the input rank for pack (or
3755  // output rank for unpack), and must match the number of `inner_dims_pos`.
3756  if (mixedTiles.size() > unpackedRank) {
3757  return op->emitError("tiling factors must be less than or equal to the "
3758  "input rank for pack or output rank for unpack");
3759  }
3760  if (mixedTiles.size() != innerDimsPos.size()) {
3761  return op->emitError(
3762  "tiling factors must equal the number of dimensions to tile");
3763  }
3764 
3765  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3766  ? packOrUnPack.getDestType()
3767  : packOrUnPack.getSourceType();
3768  size_t packedRank = packedType.getRank();
3769  // Require output rank to match input rank + number of blocking factors.
3770  if (unpackedRank + mixedTiles.size() != packedRank) {
3771  return op->emitError(
3772  "packed rank must equal unpacked rank + tiling factors");
3773  }
3774 
3775  // Verify result shape is greater than the minimum expected
3776  // by the pack operation, and that the output shape
3777  // represents full tiles.
3778  RankedTensorType expectedPackedType = PackOp::inferPackedType(
3779  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3780  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3781  return op->emitError("the shape of output is not large enough to hold the "
3782  "packed data. Expected at least ")
3783  << expectedPackedType << ", got " << packedType;
3784  }
3785  if (!llvm::all_of(
3786  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3787  mixedTiles),
3788  [](std::tuple<int64_t, OpFoldResult> it) {
3789  std::optional<int64_t> constTileSize =
3790  getConstantIntValue(std::get<1>(it));
3791  int64_t shape = std::get<0>(it);
3792  if (!constTileSize) {
3793  // If specified tile size is dynamic, output shape should
3794  // be dynamic too.
3795  return ShapedType::isDynamic(shape);
3796  }
3797  if (ShapedType::isDynamic(shape)) {
3798  // For the shape being dynamic when tile size is
3799  // specified, return true. In canonical form a constant
3800  // tile size should lead to constant shape of the tiled
3801  // dimension, but not needed for verification.
3802  return true;
3803  }
3804  return shape == constTileSize.value();
3805  })) {
3806  return op->emitError("mismatch in inner tile sizes specified and shaped of "
3807  "tiled dimension in the packed type");
3808  }
3809  return success();
3810 }
3811 
3812 namespace {
3813 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
3814 /// various permutations to the op.
3815 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
3816 // these. These may or may not become true foldings / canonicalizations
3817 // depending on how aggressive we want to be in automatically folding
3818 // transposes.
3819 struct PackOrUnPackTransposeResult {
3820  SmallVector<int64_t> innerDimsPos;
3821  SmallVector<OpFoldResult> innerTiles;
3822  SmallVector<int64_t> outerDimsPerm;
3823 };
3824 } // namespace
3825 
3826 template <typename OpTy>
3827 static PackOrUnPackTransposeResult
3828 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
3829  ArrayRef<int64_t> innerPermutation,
3830  ArrayRef<int64_t> outerPermutation) {
3831  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3832  "applies to only pack or unpack operations");
3833  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3834  "some permutation must be non-empty");
3835  PackOrUnPackTransposeResult metadata;
3836  metadata.innerDimsPos =
3837  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
3838  metadata.innerTiles =
3839  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
3840  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3841  ? packOrUnPackOp.getSourceRank()
3842  : packOrUnPackOp.getDestRank();
3843  metadata.outerDimsPerm =
3844  packOrUnPackOp.getOuterDimsPerm().empty()
3845  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3846  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
3847  if (!innerPermutation.empty()) {
3848  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3849  isPermutationVector(innerPermutation) &&
3850  "invalid inner permutation");
3851  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
3852  applyPermutationToVector(metadata.innerTiles, innerPermutation);
3853  }
3854  if (!outerPermutation.empty()) {
3855  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3856  isPermutationVector(outerPermutation) &&
3857  "invalid outer permutation");
3858  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
3859  }
3860  return metadata;
3861 }
3862 
3863 //===----------------------------------------------------------------------===//
3864 // PackOp
3865 //===----------------------------------------------------------------------===//
3866 
3867 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3868  setNameFn(getResult(), "pack");
3869 }
3870 
3871 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3872  Value dest, ArrayRef<int64_t> innerDimsPos,
3873  ArrayRef<OpFoldResult> innerTiles,
3874  std::optional<Value> paddingValue,
3875  ArrayRef<int64_t> outerDimsPerm) {
3876  assert(innerDimsPos.size() == innerTiles.size() &&
3877  "number of tile sizes specified must match the specified number of "
3878  "original dimensions to be tiled");
3879  SmallVector<int64_t> staticTileSizes;
3880  SmallVector<Value> dynamicTileSizes;
3881  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3882  build(builder, state, dest.getType(), source, dest,
3883  paddingValue ? *paddingValue : nullptr,
3884  outerDimsPerm.empty() ? nullptr
3885  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3886  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3887  builder.getDenseI64ArrayAttr(staticTileSizes));
3888 }
3889 
3890 LogicalResult
3891 PackOp::reifyResultShapes(OpBuilder &builder,
3892  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3893  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3894 }
3895 
3896 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3897  return getDimAndTileMappingImpl(*this);
3898 }
3899 
3900 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3901  return getMixedTilesImpl(*this);
3902 }
3903 
3904 SmallVector<int64_t> PackOp::getStaticTiles() {
3905  return getStaticTilesImpl(*this);
3906 }
3907 
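// A short illustration of the check below (hypothetical shapes): packing a
// tensor<13x16xf32> source with innerDimsPos = [0] and a constant tile size of
// 8 requires a padding value because 13 % 8 != 0, whereas a tile size of 4
// does not, since every tile is then full.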
3908 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
3909  ArrayRef<int64_t> innerDimsPos,
3910  ArrayRef<int64_t> outputShape,
3911  ArrayRef<int64_t> outerDimsPerm,
3912  ArrayRef<OpFoldResult> innerTiles) {
3913  SmallVector<int64_t> outputTileSizes(
3914  outputShape.take_front(inputShape.size()));
3915  if (!outerDimsPerm.empty()) {
3916  assert(outerDimsPerm.size() == outputTileSizes.size() &&
3917  "expected output and outer_dims_perm to have same size");
3918  applyPermutationToVector(outputTileSizes,
3919  invertPermutationVector(outerDimsPerm));
3920  }
3921  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3922  if (ShapedType::isDynamic(inputShape[pos]))
3923  continue;
3924  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3925 
3926  if (!constantTile) {
3927  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3928  (inputShape[pos] % outputTileSizes[pos] != 0))
3929  return true;
3930  } else if (inputShape[pos] % (*constantTile) != 0) {
3931  return true;
3932  }
3933  }
3934  return false;
3935 }
3936 
3937 LogicalResult PackOp::verify() {
3938  if (failed(commonVerifierPackAndUnPackOp(*this)))
3939  return failure();
3940 
3941  // Verify padding value, and bail out if the tile does not divide the
3942  // dimension fully. In the case of dynamic tile factors or dimensions, having
3943  // a partial tile is undefined behavior.
3944  auto paddingValue = getPaddingValue();
3945  if (paddingValue &&
3946  paddingValue.getType() != getSourceType().getElementType()) {
3947  return emitOpError("expected padding_value has ")
3948  << getSourceType().getElementType()
3949  << " but got: " << paddingValue.getType();
3950  }
3951 
3952  if (!paddingValue &&
3953  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3954  getDestType().getShape(), getOuterDimsPerm(),
3955  getMixedTiles())) {
3956  return emitOpError(
3957  "invalid tile factor or output size provided. Only full tiles are "
3958  "supported when padding_value is not set");
3959  }
3960  return success();
3961 }
3962 
3963 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
3964 /// Values to kDynamic, even if they are arith.constant values.
3965 static SmallVector<int64_t>
3966 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
3967  SmallVector<int64_t> result;
3968  for (auto o : ofrs) {
3969  // Have to do this first, as getConstantIntValue special-cases constants.
3970  if (llvm::dyn_cast_if_present<Value>(o))
3971  result.push_back(ShapedType::kDynamic);
3972  else
3973  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
3974  }
3975  return result;
3976 }
3977 
3978 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
3979 /// the packed type. Having a shared helper helps implement these two methods in
3980 /// a way that ensures that they agree on which dimensions are dynamic.
3981 static SmallVector<int64_t> getPackOpResultTypeShape(
3982  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
3983  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
3984  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
3985  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
3986  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3987  continue;
3988  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3989  resultShape[tiledDim.value()] = ShapedType::kDynamic;
3990  continue;
3991  }
3992  resultShape[tiledDim.value()] = divideCeilSigned(
3993  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
3994  }
3995 
3996  // Swap tile loops if outer_dims_perm is available.
3997  if (!outerDimsPerm.empty())
3998  applyPermutationToVector(resultShape, outerDimsPerm);
3999 
4000  // Append the inner tile dimensions.
4001  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4002  return resultShape;
4003 }
4004 
4005 SmallVector<OpFoldResult> PackOp::getResultShape(
4006  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4007  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4008  ArrayRef<int64_t> outerDimsPerm) {
4009  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4010 
4011  AffineExpr s0, s1;
4012  bindSymbols(builder.getContext(), s0, s1);
4013  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4014  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4015  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4016  builder, loc, ceilDivExpr,
4017  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4018  }
4019  if (!outerDimsPerm.empty())
4020  applyPermutationToVector(resultDims, outerDimsPerm);
4021  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4022 
4023  SmallVector<int64_t> resultTypeShape =
4024  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
4025  asShapeWithAnyValueAsDynamic(innerTileSizes),
4026  innerDimsPos, outerDimsPerm);
4027 
4028  // Fix-up `resultDims` to ensure that they are Values if and only if the
4029  // result type shape says it's a dynamic dim. This is needed as callers may
4030  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4031  // dynamic dims returned by that.
4032  for (unsigned i = 0; i < resultDims.size(); ++i) {
4033  if (!ShapedType::isDynamic(resultTypeShape[i]))
4034  continue;
4035  resultDims[i] =
4036  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4037  }
4038 
4039  return resultDims;
4040 }
4041 
4042 /// Get the expected packed type based on source type, tile factors, position of
4043 /// the inner tiles and permutation of the outer tiled loop.
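///
/// A worked example (illustrative, not from the source): for a source type
/// tensor<31x16xf32> with innerTileSizes = [8], innerDimsPos = [0] and
/// outerDimsPerm = [1, 0], dimension 0 is tiled into ceil(31 / 8) = 4 outer
/// tiles, the outer dims [4, 16] are permuted to [16, 4], and the tile sizes
/// are appended, giving tensor<16x4x8xf32>.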
4044 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4045  ArrayRef<int64_t> innerTileSizes,
4046  ArrayRef<int64_t> innerDimsPos,
4047  ArrayRef<int64_t> outerDimsPerm) {
4048  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4049  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4050  return RankedTensorType::get(resultShape, sourceType.getElementType());
4051 }
4052 
4053 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4054  ArrayRef<OpFoldResult> innerTileSizes,
4055  ArrayRef<int64_t> innerDimsPos,
4056  ArrayRef<int64_t> outerDimsPerm) {
4057  AffineExpr dim0, dim1;
4058  bindDims(b.getContext(), dim0, dim1);
4059  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4060  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4061  {v1, v2});
4062  };
4063 
4064  SmallVector<OpFoldResult> mixedSizes;
4065  for (auto [index, value] : llvm::enumerate(
4066  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4067  if (ShapedType::isDynamic(value))
4068  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
4069  else
4070  mixedSizes.push_back(b.getIndexAttr(value));
4071  }
4072  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4073  int64_t dimPos = std::get<0>(it);
4074  OpFoldResult tileSize = std::get<1>(it);
4075  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4076  }
4077  if (!outerDimsPerm.empty())
4078  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4079 
4080  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4081  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4082  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4083 }
4084 
4085 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4086  ArrayRef<int64_t> innerPermutation,
4087  ArrayRef<int64_t> outerPermutation) {
4088  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4089  *this, innerPermutation, outerPermutation);
4090  Value transposedDest =
4091  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4092  metadata.innerDimsPos, metadata.outerDimsPerm);
4093  return b.create<PackOp>(loc, getSource(), transposedDest,
4094  metadata.innerDimsPos, metadata.innerTiles,
4095  getPaddingValue(), metadata.outerDimsPerm);
4096 }
4097 
4098 /// Returns true if the tiles and the tiled dims are constant.
4099 template <typename OpTy>
4100 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
4101  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4102  "applies to only pack or unpack operations");
4103  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4104  ? op.getDestType()
4105  : op.getSourceType();
4106  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4107  for (auto [dimDest, tile] : llvm::zip(
4108  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4109  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4110  if (!constTileSize || ShapedType::isDynamic(dimDest))
4111  return false;
4112  }
4113  return true;
4114 }
4115 
4116 Speculation::Speculatability PackOp::getSpeculatability() {
4117  if (getPaddingValue())
4118  return Speculation::Speculatable;
4119 
4120  // The verifier already rejects operations where we can statically prove that
4121  // the tile sizes do not evenly divide the dimension; thus, only check that
4122  // the tiles and the tiled inner dimensions are constant.
4123  if (!areTilesAndTiledDimsAllConstant(*this))
4124  return Speculation::NotSpeculatable;
4125 
4126  return Speculation::Speculatable;
4127 }
4128 
4129 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4130 // dimensions for pack and unpack.
4131 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4132  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4133  return false;
4134  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4135  return true;
4136  // Outer dims permutation is optional.
4137  // To compare unbalanced pack-unpack pair, treat no permutation as equal to
4138  // identity permutation.
4139  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4140  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4141 }
4142 
4143 // Return true if pack and unpack have the same tiles.
4144 // Same SSA values or same integer constants.
4145 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4146  auto packTiles = packOp.getMixedTiles();
4147  auto unPackTiles = unPackOp.getMixedTiles();
4148  if (packTiles.size() != unPackTiles.size())
4149  return false;
4150  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4151  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4152  return false;
4153  }
4154  return true;
4155 }
4156 
4157 /// Returns true if the pack op does not need a padding value.
4158 static bool paddingIsNotNeeded(PackOp op) {
4159  auto srcType = op.getSourceType();
4160  if (llvm::any_of(op.getInnerDimsPos(),
4161  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4162  return false;
4163  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4164  return false;
4165  return !PackOp::requirePaddingValue(
4166  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4167  op.getOuterDimsPerm(), op.getMixedTiles());
4168 }
4169 
4170 /// Returns true if the `srcShape` or `destShape` is different from the one in
4171 /// `packOp` and populates each with the inferred static shape.
4172 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4173  SmallVectorImpl<int64_t> &destShape) {
4174  bool changeNeeded = false;
4175  srcShape.assign(packOp.getSourceType().getShape().begin(),
4176  packOp.getSourceType().getShape().end());
4177  destShape.assign(packOp.getDestType().getShape().begin(),
4178  packOp.getDestType().getShape().end());
4179  llvm::SmallSetVector<int64_t, 4> innerDims;
4180  innerDims.insert(packOp.getInnerDimsPos().begin(),
4181  packOp.getInnerDimsPos().end());
4182  SmallVector<int64_t> inverseOuterDimsPerm;
4183  if (!packOp.getOuterDimsPerm().empty())
4184  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4185  int srcRank = packOp.getSourceRank();
4186  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4187  if (innerDims.contains(i))
4188  continue;
4189  int64_t srcPos = i;
4190  int64_t destPos = i;
4191  if (!inverseOuterDimsPerm.empty())
4192  destPos = inverseOuterDimsPerm[srcPos];
4193  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4194  ShapedType::isDynamic(destShape[destPos])) {
4195  continue;
4196  }
4197  int64_t size = srcShape[srcPos];
4198  if (ShapedType::isDynamic(size))
4199  size = destShape[destPos];
4200  srcShape[srcPos] = size;
4201  destShape[destPos] = size;
4202  changeNeeded = true;
4203  }
4204  return changeNeeded;
4205 }
4206 
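// An illustrative sketch of the pack(unpack(x)) fold handled in the
// canonicalization below (hypothetical IR; %x, %d0 and %d1 are placeholders):
//
//   %u = tensor.unpack %x inner_dims_pos = [0] inner_tiles = [8]
//        into %d0 : tensor<2x8xf32> -> tensor<16xf32>
//   %p = tensor.pack %u inner_dims_pos = [0] inner_tiles = [8]
//        into %d1 : tensor<16xf32> -> tensor<2x8xf32>
//
// folds %p to %x, provided the inner/outer dim attributes and tiles match and
// no padding value is present.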
4207 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4208  // Fold a pack(unpack(x)) back to x.
4209  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4210  if (unPackOp.getSourceType() != packOp.getDestType())
4211  return failure();
4212  if (packOp.getPaddingValue() ||
4213  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4214  !haveSameTiles(packOp, unPackOp))
4215  return failure();
4216  rewriter.replaceOp(packOp, unPackOp.getSource());
4217  return success();
4218  }
4219 
4220  // Fold optional PaddingValue operand away if padding is not needed.
4221  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4222  rewriter.startOpModification(packOp);
4223  packOp.getPaddingValueMutable().clear();
4224  rewriter.finalizeOpModification(packOp);
4225  return success();
4226  }
4227 
4228  // Insert tensor.cast ops if static shape inference is available.
4229  SmallVector<int64_t> srcShape, destShape;
4230  if (inferStaticShape(packOp, srcShape, destShape)) {
4231  Location loc = packOp.getLoc();
4232  Value source = packOp.getSource();
4233  if (srcShape != packOp.getSourceType().getShape()) {
4234  auto newSrcType = packOp.getSourceType().clone(srcShape);
4235  source =
4236  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4237  }
4238  Value dest = packOp.getDest();
4239  if (destShape != packOp.getDestType().getShape()) {
4240  auto newDestType = packOp.getDestType().clone(destShape);
4241  dest =
4242  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4243  }
4244  Value newOp = rewriter.create<tensor::PackOp>(
4245  loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4246  packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4247  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4248  packOp, packOp.getResult().getType(), newOp);
4249  return success();
4250  }
4251 
4252  return failure();
4253 }
4254 
4255 template <typename PackOrUnpackOp>
4256 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4257  RankedTensorType packedTensorType) {
4258  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4259  std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4260  "Function meant for pack/unpack");
4261  // This is a pad if packing only adds ones and we don't transpose dimensions.
4262 
4263  // Check that we are not transposing any dimensions.
4264  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4265  int64_t numPackedDims = innerDimsPos.size();
4266  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4267  if (orderedDims != innerDimsPos) {
4268  // Dimensions don't happen in order.
4269  return false;
4270  }
4271 
4272  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4273  int64_t packedRank = packedTensorType.getRank();
4274  // At this point we know that we are taking numPackedDims outer
4275  // dimensions and pushing them all the way in as the innermost dimensions.
4276  // What's left on the outermost dimensions is, in this order:
4277  // - the factor of the packed dimensions, then
4278  // - the untouched dimensions.
4279  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4280  // if all the dimensions that bubble outward are ones.
4281  // Therefore check that all the dimensions but the numPackedDims innermost
4282  // ones are ones.
4283  return llvm::all_of(
4284  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4285  [&packedShape](int64_t i) { return packedShape[i] == 1; });
4286 }
4287 
4288 bool PackOp::isLikePad() {
4289  auto packedTensorType =
4290  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4291  return isLikePadUnPad(*this, packedTensorType);
4292 }
4293 
4294 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4295  std::optional<Attribute> paddingValue;
4296  if (auto pad = adaptor.getPaddingValue())
4297  paddingValue = pad;
4298  if (OpFoldResult reshapedSource = reshapeConstantSource(
4299  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4300  getDestType(), paddingValue))
4301  return reshapedSource;
4302  return {};
4303 }
4304 
4305 //===----------------------------------------------------------------------===//
4306 // UnPackOp
4307 //===----------------------------------------------------------------------===//
4308 
4309 void UnPackOp::getAsmResultNames(
4310  function_ref<void(Value, StringRef)> setNameFn) {
4311  setNameFn(getResult(), "unpack");
4312 }
4313 
4314 LogicalResult
4315 UnPackOp::reifyResultShapes(OpBuilder &builder,
4316  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4317  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4318 }
4319 
4320 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4321  return getDimAndTileMappingImpl(*this);
4322 }
4323 
4324 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4325  return getMixedTilesImpl(*this);
4326 }
4327 
4328 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4329  return getStaticTilesImpl(*this);
4330 }
4331 
4332 LogicalResult UnPackOp::verify() {
4333  return commonVerifierPackAndUnPackOp(*this);
4334 }
4335 
4336 Speculation::Speculatability UnPackOp::getSpeculatability() {
4337  // See PackOp::getSpeculatability.
4338  if (!areTilesAndTiledDimsAllConstant(*this))
4339  return Speculation::NotSpeculatable;
4340 
4341  return Speculation::Speculatable;
4342 }
4343 
4344 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4345  Value dest, ArrayRef<int64_t> innerDimsPos,
4346  ArrayRef<OpFoldResult> innerTiles,
4347  ArrayRef<int64_t> outerDimsPerm) {
4348  assert(innerDimsPos.size() == innerTiles.size() &&
4349  "number of tile sizes specified must match the specified number of "
4350  "original dimensions to be tiled");
4351  SmallVector<int64_t> staticTileSizes;
4352  SmallVector<Value> dynamicTileSizes;
4353  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4354  build(builder, state, dest.getType(), source, dest,
4355  outerDimsPerm.empty() ? nullptr
4356  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4357  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4358  builder.getDenseI64ArrayAttr(staticTileSizes));
4359 }
4360 
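// Sizing logic, illustrated with hypothetical shapes: the destination keeps
// the source's outer sizes (after undoing any outer_dims_perm) and multiplies
// each tiled dim by its tile size. For a source tensor<16x8x8x32xf32> with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 32], this builds a
// tensor.empty of type tensor<128x256xf32> (16*8 by 8*32).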
4361 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4362  Value source,
4363  ArrayRef<OpFoldResult> innerTileSizes,
4364  ArrayRef<int64_t> innerDimsPos,
4365  ArrayRef<int64_t> outerDimsPerm) {
4366  AffineExpr sym0, sym1;
4367  bindSymbols(b.getContext(), sym0, sym1);
4368  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4369  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4370  };
4371 
4372  SmallVector<OpFoldResult> mixedSizes;
4373  auto srcType = llvm::cast<RankedTensorType>(source.getType());
4374  for (auto i :
4375  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4376  if (srcType.isDynamicDim(i))
4377  mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4378  else
4379  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4380  }
4381  if (!outerDimsPerm.empty()) {
4382  applyPermutationToVector<OpFoldResult>(
4383  mixedSizes, invertPermutationVector(outerDimsPerm));
4384  }
4385 
4386  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4387  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4388 
4389  auto elemType = srcType.getElementType();
4390  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4391 }
4392 
4393 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4394  Value transposedSource,
4395  ArrayRef<int64_t> innerPermutation,
4396  ArrayRef<int64_t> outerPermutation) {
4397  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4398  *this, innerPermutation, outerPermutation);
4399  return b.create<UnPackOp>(loc, transposedSource, getDest(),
4400  metadata.innerDimsPos, metadata.innerTiles,
4401  metadata.outerDimsPerm);
4402 }
4403 
4404 /// Returns true if the `srcShape` or `destShape` is different from the one in
4405 /// `op` and populates each with the inferred static shape.
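///
/// Illustrative example (hypothetical shapes): for an unpack with
/// inner_dims_pos = [1] from tensor<4x8x32xf32> into tensor<?x256xf32>, the
/// non-tiled destination dim 0 is inferred from the source, giving
/// destShape = [4, 256]; the tiled dim 1 is left as-is.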
4406 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4407  SmallVectorImpl<int64_t> &destShape) {
4408  bool changeNeeded = false;
4409  srcShape.assign(op.getSourceType().getShape().begin(),
4410  op.getSourceType().getShape().end());
4411  destShape.assign(op.getDestType().getShape().begin(),
4412  op.getDestType().getShape().end());
4413  llvm::SmallSetVector<int64_t, 4> innerDims;
4414  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4415  SmallVector<int64_t> inverseOuterDimsPerm;
4416  if (!op.getOuterDimsPerm().empty())
4417  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4418  int destRank = op.getDestRank();
4419  for (auto i : llvm::seq<int64_t>(0, destRank)) {
4420  if (innerDims.contains(i))
4421  continue;
4422  int64_t srcPos = i;
4423  int64_t destPos = i;
4424  if (!inverseOuterDimsPerm.empty())
4425  srcPos = inverseOuterDimsPerm[destPos];
4426  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4427  ShapedType::isDynamic(destShape[destPos])) {
4428  continue;
4429  }
4430  int64_t size = srcShape[srcPos];
4431  if (ShapedType::isDynamic(size))
4432  size = destShape[destPos];
4433  srcShape[srcPos] = size;
4434  destShape[destPos] = size;
4435  changeNeeded = true;
4436  }
4437  return changeNeeded;
4438 }
4439 
4440 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4441  PatternRewriter &rewriter) {
4442  /// unpack(pack(x)) -> x
4443  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4444  if (packOp.getDestType() != unPackOp.getSourceType())
4445  return failure();
4446  if (packOp.getPaddingValue() ||
4447  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4448  !haveSameTiles(packOp, unPackOp))
4449  return failure();
4450  rewriter.replaceOp(unPackOp, packOp.getSource());
4451  return success();
4452  }
4453  /// unpack(destinationStyleOp(x)) -> unpack(x)
4454  if (auto dstStyleOp =
4455  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4456  auto destValue = cast<OpResult>(unPackOp.getDest());
4457  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4458  rewriter.modifyOpInPlace(unPackOp,
4459  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4460  return success();
4461  }
4462 
4463  // Insert tensor.cast ops if static shape inference is available.
4464  SmallVector<int64_t> srcShape, destShape;
4465  if (inferStaticShape(unPackOp, srcShape, destShape)) {
4466  Location loc = unPackOp.getLoc();
4467  Value source = unPackOp.getSource();
4468  if (srcShape != unPackOp.getSourceType().getShape()) {
4469  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4470  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4471  unPackOp.getSource());
4472  }
4473  Value dest = unPackOp.getDest();
4474  if (destShape != unPackOp.getDestType().getShape()) {
4475  auto newDestType = unPackOp.getDestType().clone(destShape);
4476  dest =
4477  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4478  }
4479  Value newOp = rewriter.create<tensor::UnPackOp>(
4480  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4481  unPackOp.getOuterDimsPerm());
4482  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4483  unPackOp, unPackOp.getResult().getType(), newOp);
4484  return success();
4485  }
4486 
4487  return failure();
4488 }
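// Example (illustrative) for the unpack(pack(x)) -> x pattern above: with no
// padding value, matching inner/outer dim attributes, and identical tiles, the
// round trip below folds away and %u is replaced by %x:
//
//   %p = tensor.pack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//          into %pd : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//   %u = tensor.unpack %p inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//          into %ud : tensor<16x8x8x32xf32> -> tensor<128x256xf32>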
4489 
4490 bool UnPackOp::isLikeUnPad() {
4491  RankedTensorType packedTensorType = getSourceType();
4492  return isLikePadUnPad(*this, packedTensorType);
4493 }
4494 
4495 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4496  if (OpFoldResult reshapedSource = reshapeConstantSource(
4497  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4498  getResult().getType()))
4499  return reshapedSource;
4500  return {};
4501 }
4502 
4503 //===----------------------------------------------------------------------===//
4504 // Common Canonicalizers and Folders.
4505 //===----------------------------------------------------------------------===//
4506 
4507 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4508 /// the `tensor.cast` has a source that is more static than the consuming op.
4509 ///
4510 /// Example:
4511 /// ```mlir
4512 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4513 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4514 /// ```
4515 ///
4516 /// folds into:
4517 ///
4518 /// ```mlir
4519 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4520 /// ```
4521 /// TODO: Move the pattern to a proper place, so that all other
4522 /// DestinationStyleOp implementations can add the pattern to their canonicalizers.
4523 struct FoldTensorCastProducerOp
4524  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4525  using OpInterfaceRewritePattern<
4526  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4527 
4528  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4529  PatternRewriter &rewriter) const override {
4530  // InsertSliceOp has its own logic about folding tensor.cast ops.
4531  if (isa<InsertSliceOp>(op.getOperation()))
4532  return failure();
4533 
4534  // Exclude DPS ops that are also LoopLike from this interface as they
4535  // might need special handling of attached regions.
4536  if (isa<LoopLikeOpInterface>(op.getOperation()))
4537  return failure();
4538 
4539  // If no operand comes from a tensor::CastOp that can be folded, then fail.
4540  bool hasTensorCastOperand =
4541  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4542  if (llvm::isa<BlockArgument>(opOperand.get()))
4543  return false;
4544  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4545  return castOp && canFoldIntoConsumerOp(castOp);
4546  });
4547  if (!hasTensorCastOperand)
4548  return failure();
4549 
4550  SmallVector<Type, 4> newResultTypes(op->getResultTypes());
4551  SmallVector<Value, 4> newOperands;
4552  newOperands.reserve(op->getNumOperands());
4553  // Assumes that the result has dpsInits followed by nonDpsInits.
4554  int64_t dpsInitIdx = 0;
4555  for (OpOperand &opOperand : op->getOpOperands()) {
4556  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4557  bool fold = canFoldIntoConsumerOp(tensorCastOp);
4558  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4559  if (op.isDpsInit(&opOperand) &&
4560  !llvm::isa<MemRefType>(newOperands.back().getType()))
4561  newResultTypes[dpsInitIdx++] = newOperands.back().getType();
4562  }
4563 
4564  // Clone op.
4565  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
4566  SmallVector<Value, 4> replacements;
4567  replacements.reserve(newOp->getNumResults());
4568  for (auto [oldResult, newResult] :
4569  llvm::zip(op->getResults(), newOp->getResults())) {
4570  if (newResult.getType() != oldResult.getType()) {
4571  replacements.push_back(rewriter.create<tensor::CastOp>(
4572  op->getLoc(), oldResult.getType(), newResult));
4573  } else {
4574  replacements.push_back(newResult);
4575  }
4576  }
4577  rewriter.replaceOp(op, replacements);
4578 
4579  return success();
4580  }
4581 };
4582 
4583 //===----------------------------------------------------------------------===//
4584 // TensorDialect
4585 //===----------------------------------------------------------------------===//
4586 
4587 void TensorDialect::getCanonicalizationPatterns(
4588  RewritePatternSet &results) const {
4589  results.add<FoldTensorCastProducerOp>(getContext());
4590 }
4591 
4592 //===----------------------------------------------------------------------===//
4593 // TableGen'd op method definitions
4594 //===----------------------------------------------------------------------===//
4595 
4596 #define GET_OP_CLASSES
4597 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"