TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinTypeInterfaces.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <algorithm>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::tensor;
37 
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
40 using llvm::mod;
41 
42 /// Materialize a single constant operation from a given attribute value with
43 /// the desired resultant type.
44 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
45  Attribute value, Type type,
46  Location loc) {
47  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
48  return op;
49  if (complex::ConstantOp::isBuildableWith(value, type))
50  return builder.create<complex::ConstantOp>(loc, type,
51  llvm::cast<ArrayAttr>(value));
52  return nullptr;
53 }
54 
55 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
56  int64_t dim) {
57  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
59  if (tensorType.isDynamicDim(dim))
60  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
61 
62  return builder.getIndexAttr(tensorType.getDimSize(dim));
63 }
64 
65 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
66  Location loc, Value value) {
67  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
68  SmallVector<OpFoldResult> result;
69  for (int64_t i = 0; i < tensorType.getRank(); ++i)
70  result.push_back(getMixedSize(builder, loc, value, i));
71  return result;
72 }
73 
74 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
75  OpResult opResult) {
76  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
77  assert(tensorType && "expected tensor type");
78 
79  // If the op has a destination, it implements DestinationStyleOpInterface and
80  // we can query the destination operand from that interface.
81  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
82  if (destOp)
83  return destOp.getTiedOpOperand(opResult)->get();
84 
85  // Otherwise, create a new destination tensor with the same shape.
86  OpBuilder::InsertionGuard g(b);
87  b.setInsertionPoint(opResult.getDefiningOp());
88 
89  // Compute sizes.
90  SmallVector<OpFoldResult> mixedSizes;
91  if (!tensorType.hasStaticShape()) {
92  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
93  ReifiedRankedShapedTypeDims reifiedShapes;
94  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
95  return failure();
96  mixedSizes = reifiedShapes[opResult.getResultNumber()];
97  } else {
98  // Static shape: Take static sizes directly.
99  for (int64_t sz : tensorType.getShape())
100  mixedSizes.push_back(b.getIndexAttr(sz));
101  }
102 
103  // Create empty tensor.
104  Value emptyTensor =
105  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
106  return emptyTensor;
107 }
108 
109 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
110  Operation *op,
111  SmallVector<Value> &result) {
112  for (OpResult opResult : op->getResults()) {
113  if (llvm::isa<TensorType>(opResult.getType())) {
114  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
115  if (failed(destination))
116  return failure();
117  result.push_back(*destination);
118  }
119  }
120  return success();
121 }
122 
123 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
124  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
125  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
126  return rtp1.getShape() == rtp2.getShape() &&
127  rtp1.getElementType() == rtp2.getElementType();
128  return false;
129  }
130  return tp1 == tp2; // default implementation
131 }
132 
133 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
134 /// rank-extending tensor.insert_slice op.
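/// As an illustration (hypothetical shapes): for an extract_slice with mixed
/// sizes [1, %sz, 1, 16] producing a rank-reduced tensor of shape ?x16,
/// dimensions 0 and 2 are the dropped unit dimensions, so bits 0 and 2 of the
/// returned bit vector are set.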
135 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
136  ArrayRef<OpFoldResult> mixedSizes) {
137  llvm::SmallBitVector droppedDims(mixedSizes.size());
138  int64_t shapePos = reducedShape.size() - 1;
139 
140  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
141  size_t idx = mixedSizes.size() - size.index() - 1;
142  // Rank-reduced dims must have a static unit dimension.
143  bool isStaticUnitSize =
144  size.value().is<Attribute>() &&
145  llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
146 
147  if (shapePos < 0) {
148  // There are no more dims in the reduced shape. All remaining sizes must
149  // be rank-reduced dims.
150  assert(isStaticUnitSize && "expected unit dim");
151  droppedDims.set(idx);
152  continue;
153  }
154 
155  // Dim is preserved if the size is not a static 1.
156  if (!isStaticUnitSize) {
157  --shapePos;
158  continue;
159  }
160 
161  // Dim is preserved if the reduced shape dim is also 1.
162  if (reducedShape[shapePos] == 1) {
163  --shapePos;
164  continue;
165  }
166 
167  // Otherwise: Dim is dropped.
168  droppedDims.set(idx);
169  }
170 
171  assert(shapePos < 0 && "dimension mismatch");
172  return droppedDims;
173 }
174 
175 /// Given a ranked tensor type and a range of values that defines its dynamic
176 /// dimension sizes, turn all dynamic sizes that have a constant value into
177 /// static dimension sizes.
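/// For instance (illustrative): a tensor<?x?xf32> with dynamic sizes (%c4, %n),
/// where %c4 is a constant 4, folds to tensor<4x?xf32>, and the remaining list
/// of dynamic sizes becomes (%n).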
178 static RankedTensorType
179 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
180  SmallVector<Value> &foldedDynamicSizes) {
181  SmallVector<int64_t> staticShape(type.getShape());
182  assert(type.getNumDynamicDims() ==
183  static_cast<int64_t>(dynamicSizes.size()) &&
184  "incorrect number of dynamic sizes");
185 
186  // Compute new static and dynamic sizes.
187  unsigned ctr = 0;
188  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189  if (type.isDynamicDim(i)) {
190  Value dynamicSize = dynamicSizes[ctr++];
191  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192  if (cst.has_value()) {
193  // Dynamic size must be non-negative.
194  if (cst.value() < 0) {
195  foldedDynamicSizes.push_back(dynamicSize);
196  continue;
197  }
198  staticShape[i] = *cst;
199  } else {
200  foldedDynamicSizes.push_back(dynamicSize);
201  }
202  }
203  }
204 
205  return RankedTensorType::get(staticShape, type.getElementType(),
206  type.getEncoding());
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // BitcastOp
211 //===----------------------------------------------------------------------===//
212 
213 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214  if (inputs.size() != 1 || outputs.size() != 1)
215  return false;
216  Type a = inputs.front(), b = outputs.front();
217  auto aT = dyn_cast<TensorType>(a);
218  auto bT = dyn_cast<TensorType>(b);
219  if (!aT || !bT)
220  return false;
221 
222  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223  return false;
224 
225  return succeeded(verifyCompatibleShape(aT, bT));
226 }
227 
228 namespace {
229 
230 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231 /// operation.
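/// For example (illustrative):
///   %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
///   %2 = tensor.bitcast %1 : tensor<4xi32> to tensor<4xsi32>
/// is rewritten into a single bitcast from %0 directly to tensor<4xsi32>.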
232 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233  using OpRewritePattern<BitcastOp>::OpRewritePattern;
234 
235  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236  PatternRewriter &rewriter) const final {
237  auto tensorBitcastOperand =
238  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239  if (!tensorBitcastOperand)
240  return failure();
241 
242  auto resultType = cast<TensorType>(tensorBitcast.getType());
243  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244  tensorBitcastOperand.getOperand());
245  return success();
246  }
247 };
248 
249 } // namespace
250 
251 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252  MLIRContext *context) {
253  results.add<ChainedTensorBitcast>(context);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // CastOp
258 //===----------------------------------------------------------------------===//
259 
260 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261  setNameFn(getResult(), "cast");
262 }
263 
264 /// Returns true if `target` is a ranked tensor type that preserves static
265 /// information available in the `source` ranked tensor type.
266 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
267  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
269 
270  // Requires RankedTensorType.
271  if (!sourceType || !targetType)
272  return false;
273 
274  // Requires same elemental type.
275  if (sourceType.getElementType() != targetType.getElementType())
276  return false;
277 
278  // Requires same rank.
279  if (sourceType.getRank() != targetType.getRank())
280  return false;
281 
282  // Requires same encoding.
283  if (sourceType.getEncoding() != targetType.getEncoding())
284  return false;
285 
286  // If cast is towards more static sizes along any dimension, don't fold.
287  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288  if (!ShapedType::isDynamic(std::get<0>(t)) &&
289  ShapedType::isDynamic(std::get<1>(t)))
290  return false;
291  }
292 
293  return true;
294 }
295 
296 /// Determines whether tensor::CastOp casts to a more dynamic version of the
297 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
298 /// implement canonicalization patterns for ops in different dialects that may
299 /// consume the results of tensor.cast operations. Such foldable tensor.cast
300 /// operations are typically inserted as `slice` ops and are canonicalized,
301 /// to preserve the type compatibility of their uses.
302 ///
303 /// Returns true when all conditions are met:
304 /// 1. source and result are ranked tensors with same element type and rank.
305 /// 2. the source tensor type has more static information than the result type
306 ///
307 /// Example:
308 /// ```mlir
309 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
311 /// ```
312 ///
313 /// folds into:
314 ///
315 /// ```mlir
316 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
317 /// ```
318 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
319  if (!castOp)
320  return false;
321 
322  // Can fold if the source of cast has at least as much static information as
323  // its results.
324  return preservesStaticInformation(castOp.getType(),
325  castOp.getSource().getType());
326 }
327 
328 /// Determines whether the tensor::CastOp casts to a more static version of the
329 /// source tensor. This is useful to fold into a producing op and implement
330 /// canonicalization patterns with the `tensor.cast` op as the root, but the
331 /// producer being from a different dialect. Returns true when all conditions are met:
332 /// 1. source and result are ranked tensors with same element type and rank.
333 /// 2. the result type has more static information than the source.
334 ///
335 /// Example:
336 /// ```mlir
337 /// %1 = producer ... : tensor<?x?xf32>
338 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
339 /// ```
340 ///
341 /// can be canonicalized to :
342 ///
343 /// ```mlir
344 /// %2 = producer ... : tensor<8x16xf32>
345 /// ```
346 /// Not all ops might be canonicalizable this way, but for those that can be,
347 /// this method provides a check that it is worth doing the canonicalization.
348 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
349  if (!castOp)
350  return false;
351  return preservesStaticInformation(castOp.getSource().getType(),
352  castOp.getType());
353 }
354 
355 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
356 /// that can be folded.
357 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
358  bool folded = false;
359  for (OpOperand &operand : op->getOpOperands()) {
360  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
361  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
362  operand.set(castOp.getOperand());
363  folded = true;
364  }
365  }
366  return success(folded);
367 }
368 
369 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
370  if (inputs.size() != 1 || outputs.size() != 1)
371  return false;
372  Type a = inputs.front(), b = outputs.front();
373  auto aT = llvm::dyn_cast<TensorType>(a);
374  auto bT = llvm::dyn_cast<TensorType>(b);
375  if (!aT || !bT)
376  return false;
377 
378  if (aT.getElementType() != bT.getElementType())
379  return false;
380 
381  return succeeded(verifyCompatibleShape(aT, bT));
382 }
383 
384 /// Compute a TensorType that has the joined shape knowledge of the two
385 /// given TensorTypes. The element types need to match.
386 static TensorType joinShapes(TensorType one, TensorType two) {
387  assert(one.getElementType() == two.getElementType());
388 
389  if (!one.hasRank())
390  return two;
391  if (!two.hasRank())
392  return one;
393 
394  int64_t rank = one.getRank();
395  if (rank != two.getRank())
396  return {};
397 
398  SmallVector<int64_t> join;
399  join.reserve(rank);
400  for (int64_t i = 0; i < rank; ++i) {
401  if (one.isDynamicDim(i)) {
402  join.push_back(two.getDimSize(i));
403  continue;
404  }
405  if (two.isDynamicDim(i)) {
406  join.push_back(one.getDimSize(i));
407  continue;
408  }
409  if (one.getDimSize(i) != two.getDimSize(i))
410  return {};
411  join.push_back(one.getDimSize(i));
412  }
413  return RankedTensorType::get(join, one.getElementType());
414 }
415 
416 namespace {
417 
418 /// Replaces chains of two tensor.cast operations by a single tensor.cast
419 /// operation if doing so does not remove runtime constraints.
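/// For example (illustrative):
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<8x?xf32>
///   %2 = tensor.cast %1 : tensor<8x?xf32> to tensor<?x?xf32>
/// becomes a single cast from %0 to tensor<?x?xf32>, since dropping the
/// intermediate cast loses no runtime shape information.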
420 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
421  using OpRewritePattern<CastOp>::OpRewritePattern;
422 
423  LogicalResult matchAndRewrite(CastOp tensorCast,
424  PatternRewriter &rewriter) const final {
425  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
426 
427  if (!tensorCastOperand)
428  return failure();
429 
430  auto sourceType =
431  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
432  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
433  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
434 
435  // We can remove the intermediate cast if joining all three produces the
436  // same result as just joining the source and result shapes.
437  auto firstJoin =
438  joinShapes(joinShapes(sourceType, intermediateType), resultType);
439 
440  // The join might not exist if the cast sequence would fail at runtime.
441  if (!firstJoin)
442  return failure();
443 
444  // The newJoin always exists if the above join exists; it might just contain
445  // less information. If so, we cannot drop the intermediate cast, as doing
446  // so would remove runtime checks.
447  auto newJoin = joinShapes(sourceType, resultType);
448  if (firstJoin != newJoin)
449  return failure();
450 
451  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
452  tensorCastOperand.getOperand());
453  return success();
454  }
455 };
456 
457 /// Fold tensor.cast into tensor.extract_slice producer.
458 /// Example:
459 /// ```
460 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
461 /// tensor<128x512xf32> to tensor<?x512xf32>
462 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
463 /// ```
464 /// ->
465 /// ```
466 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
467 /// tensor<128x512xf32> to tensor<16x512xf32>
468 /// ```
469 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
470  using OpRewritePattern<CastOp>::OpRewritePattern;
471 
472  LogicalResult matchAndRewrite(CastOp tensorCast,
473  PatternRewriter &rewriter) const final {
474  auto extractOperand =
475  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
476 
477  // Cannot fold cast to unranked tensor.
478  auto rankedResultType =
479  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
480  if (!rankedResultType)
481  return failure();
482 
483  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
484  rankedResultType.getShape() ==
485  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
486  .getShape())
487  return failure();
488 
489  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
490  auto dimMask = computeRankReductionMask(
491  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
492  size_t dimIndex = 0;
493  for (size_t i = 0, e = sizes.size(); i < e; i++) {
494  if (dimMask && dimMask->count(i))
495  continue;
496  int64_t dim = rankedResultType.getShape()[dimIndex++];
497  if (ShapedType::isDynamic(dim))
498  continue;
499  sizes[i] = rewriter.getIndexAttr(dim);
500  }
501 
502  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
503  tensorCast, rankedResultType, extractOperand.getSource(),
504  extractOperand.getMixedOffsets(), sizes,
505  extractOperand.getMixedStrides());
506  return success();
507  }
508 };
509 
510 } // namespace
511 
512 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
513  MLIRContext *context) {
514  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // ConcatOp
519 //===----------------------------------------------------------------------===//
520 
521 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
522  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
523  auto tensorTypes =
524  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
525  return llvm::cast<RankedTensorType>(type);
526  }));
527  int64_t concatRank = tensorTypes[0].getRank();
528 
529  // The concatenation dim must be in the range [0, rank).
530  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
531 
532  SmallVector<int64_t> sizes(concatRank);
533  for (int64_t i = 0, e = concatRank; i < e; ++i) {
534  if (i == dim)
535  continue;
536  SaturatedInteger size;
537  for (auto tensorType : tensorTypes)
538  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
539  sizes[i] = size.asInteger();
540  }
541  auto concatSize = SaturatedInteger::wrap(0);
542  for (auto tensorType : tensorTypes)
543  concatSize =
544  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
545  sizes[dim] = concatSize.asInteger();
546  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
547 }
548 
549 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
550  ValueRange inputs) {
551  FailureOr<RankedTensorType> resultType =
552  inferResultType(dim, inputs.getTypes());
553  assert(succeeded(resultType) && "failed to infer concatenation result type");
554  build(builder, result, *resultType, dim, inputs);
555 }
556 
557 LogicalResult ConcatOp::verify() {
558  if (getInputs().size() < 1)
559  return emitOpError("requires at least one input");
560 
561  SmallVector<RankedTensorType> inputTypes;
562  for (auto input : getInputs())
563  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
564 
565  RankedTensorType resultType = getResultType();
566  int64_t resultRank = getRank();
567  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
568  return type.getRank() != resultRank;
569  }))
570  return emitOpError("rank of concatenated inputs must match result rank");
571 
572  Type resultElementType = resultType.getElementType();
573  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
574  return type.getElementType() != resultElementType;
575  }))
576  return emitOpError("inputs and result element type must match");
577 
578  int64_t dim = getDim();
579  if (dim >= resultRank)
580  return emitOpError("concatenation dim must be less than the tensor rank");
581 
582  SmallVector<int64_t> sizes(resultRank);
583  for (int64_t i = 0, e = resultRank; i < e; ++i) {
584  if (i == dim)
585  continue;
586  SaturatedInteger size;
587  for (auto tensorType : inputTypes) {
588  FailureOr<SaturatedInteger> maybeSize =
589  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
590  if (failed(maybeSize))
591  return emitOpError("static concatenation size mismatch along ")
592  << "non-concatenated dimension " << i;
593  size = *maybeSize;
594  }
595  sizes[i] = size.asInteger();
596  }
597  auto concatSize = SaturatedInteger::wrap(0);
598  for (auto tensorType : inputTypes)
599  concatSize =
600  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
601  sizes[dim] = concatSize.asInteger();
602  auto inferredResultType =
603  RankedTensorType::get(sizes, inputTypes[0].getElementType());
604 
605  for (auto [inferredSize, actualSize] :
606  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
607  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
608  ShapedType::isDynamic(actualSize);
609  if (!hasDynamic && inferredSize != actualSize)
610  return emitOpError("result type ")
611  << resultType << " does not match inferred shape "
612  << inferredResultType << " static sizes";
613  }
614 
615  return success();
616 }
617 
618 LogicalResult
619 ConcatOp::reifyResultShapes(OpBuilder &builder,
620  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
621  ValueRange inputs = getInputs();
622  int64_t dim = getDim();
623  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
624 
625  Value init = inputs[0];
626  int64_t rank = getType().getRank();
627 
628  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
629 
630  // Pre-populate the result sizes with as much static information as possible
631  // from the given result type, as well as the inferred result type, otherwise
632  // use the dim sizes from the first input.
633  for (int64_t i = 0; i < rank; ++i) {
634  if (i == dim)
635  continue;
636  if (!getType().isDynamicDim(i)) {
637  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
638  } else if (!inferredResultType.isDynamicDim(i)) {
639  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
640  builder, getLoc(),
641  builder.getIndexAttr(inferredResultType.getDimSize(i)));
642  } else {
643  reifiedReturnShapes[0][i] =
644  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
645  }
646  }
647 
648  if (getType().isDynamicDim(dim)) {
649  // Take the sum of the input sizes along the concatenated dim.
650  AffineExpr sum = builder.getAffineDimExpr(0);
651  SmallVector<OpFoldResult> sizes = {
652  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
653  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
654  sum = sum + builder.getAffineDimExpr(idx + 1);
655  sizes.push_back(
656  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
657  }
658  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
659  builder, getLoc(),
660  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
661  } else {
662  // If the result shape is static along the concatenated dim, use the static
663  // shape.
664  reifiedReturnShapes[0][dim] =
665  builder.getIndexAttr(getType().getDimSize(dim));
666  }
667  return success();
668 }
669 
670 void ConcatOp::getAsmResultNames(
671  function_ref<void(Value, StringRef)> setNameFn) {
672  setNameFn(getResult(), "concat");
673 }
674 
675 OpFoldResult ConcatOp::fold(FoldAdaptor) {
676  ValueRange inputs = getInputs();
677  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
678  return inputs[0];
679  return {};
680 }
681 
682 namespace {
683 /// Fold a concat op with a single input to a cast.
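/// For example (illustrative):
///   %0 = tensor.concat dim(0) %arg0 : (tensor<?xf32>) -> tensor<4xf32>
/// becomes
///   %0 = tensor.cast %arg0 : tensor<?xf32> to tensor<4xf32>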
684 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
685  using OpRewritePattern<ConcatOp>::OpRewritePattern;
686 
687  LogicalResult matchAndRewrite(ConcatOp concatOp,
688  PatternRewriter &rewriter) const override {
689  if (concatOp.getInputs().size() != 1)
690  return failure();
691  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
692  concatOp.getInputs()[0]);
693  return success();
694  }
695 };
696 } // namespace
697 
698 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
699  MLIRContext *context) {
700  results.add<SingleInputConcatOp>(context);
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // DimOp
705 //===----------------------------------------------------------------------===//
706 
707 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
708  setNameFn(getResult(), "dim");
709 }
710 
711 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
712  int64_t index) {
713  auto loc = result.location;
714  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
715  build(builder, result, source, indexValue);
716 }
717 
718 std::optional<int64_t> DimOp::getConstantIndex() {
719  return getConstantIntValue(getIndex());
720 }
721 
722 Speculation::Speculatability DimOp::getSpeculatability() {
723  auto constantIndex = getConstantIndex();
724  if (!constantIndex)
725  return Speculation::NotSpeculatable;
726 
727  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
728  if (!rankedSourceType)
729  return Speculation::NotSpeculatable;
730 
731  if (rankedSourceType.getRank() <= constantIndex)
732  return Speculation::NotSpeculatable;
733 
734  return Speculation::Speculatable;
735 }
736 
737 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
738  // All forms of folding require a known index.
739  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
740  if (!index)
741  return {};
742 
743  // Folding for unranked types (UnrankedTensorType) is not supported.
744  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
745  if (!tensorType)
746  return {};
747 
748  // Out of bound indices produce undefined behavior but are still valid IR.
749  // Don't choke on them.
750  int64_t indexVal = index.getInt();
751  if (indexVal < 0 || indexVal >= tensorType.getRank())
752  return {};
753 
754  // Fold if the shape extent along the given index is known.
755  if (!tensorType.isDynamicDim(index.getInt())) {
756  Builder builder(getContext());
757  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
758  }
759 
760  Operation *definingOp = getSource().getDefiningOp();
761 
762  // Fold dim to the operand of tensor.generate.
763  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
764  auto resultType =
765  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
766  // The case where the type encodes the size of the dimension is handled
767  // above.
768  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
769 
770  // Find the operand of the fromElements that corresponds to this index.
771  auto dynExtents = fromElements.getDynamicExtents().begin();
772  for (auto dim : resultType.getShape().take_front(index.getInt()))
773  if (ShapedType::isDynamic(dim))
774  dynExtents++;
775 
776  return Value{*dynExtents};
777  }
778 
779  // The size at the given index is now known to be a dynamic size.
780  unsigned unsignedIndex = index.getValue().getZExtValue();
781 
782  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
783  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
784  // `resolve-shaped-type-result-dims` pass.
785  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
786  sliceOp.isDynamicSize(unsignedIndex)) {
787  return {sliceOp.getDynamicSize(unsignedIndex)};
788  }
789  }
790 
791  // dim(cast) -> dim
792  if (succeeded(foldTensorCast(*this)))
793  return getResult();
794 
795  return {};
796 }
797 
798 namespace {
799 /// Fold dim of a cast into the dim of the source of the tensor cast.
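/// For example (illustrative):
///   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %d = tensor.dim %0, %c1 : tensor<?x?xf32>
/// is rewritten so that the dim queries %arg0 directly.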
800 struct DimOfCastOp : public OpRewritePattern<DimOp> {
801  using OpRewritePattern<DimOp>::OpRewritePattern;
802 
803  LogicalResult matchAndRewrite(DimOp dimOp,
804  PatternRewriter &rewriter) const override {
805  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
806  if (!castOp)
807  return failure();
808  Value newSource = castOp.getOperand();
809  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
810  return success();
811  }
812 };
813 
814 /// Fold dim of a destination passing style op into the dim of the corresponding
815 /// init.
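/// For example (illustrative), a `tensor.dim` of a value produced by a
/// `linalg.generic` is folded to a `tensor.dim` of the generic's corresponding
/// init (outs) operand, since destination-style ops return tensors with the
/// same shape as their inits.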
816 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
817  using OpRewritePattern<DimOp>::OpRewritePattern;
818 
819  LogicalResult matchAndRewrite(DimOp dimOp,
820  PatternRewriter &rewriter) const override {
821  auto source = dimOp.getSource();
822  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
823  if (!destOp)
824  return failure();
825 
826  auto resultIndex = cast<OpResult>(source).getResultNumber();
827  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
828 
829  rewriter.modifyOpInPlace(
830  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
831  return success();
832  }
833 };
834 
835 /// Fold dim of a tensor reshape operation into an extract from the reshape's shape
836 /// operand.
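/// For example (illustrative):
///   %r = tensor.reshape %t(%shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
///   %d = tensor.dim %r, %idx : tensor<?x?xf32>
/// becomes
///   %d = tensor.extract %shape[%idx] : tensor<2xindex>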
837 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
838  using OpRewritePattern<DimOp>::OpRewritePattern;
839 
840  LogicalResult matchAndRewrite(DimOp dim,
841  PatternRewriter &rewriter) const override {
842  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
843 
844  if (!reshape)
845  return failure();
846 
847  // Since tensors are immutable we don't need to worry about where to place
848  // the extract call
849  rewriter.setInsertionPointAfter(dim);
850  Location loc = dim.getLoc();
851  Value extract =
852  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
853  if (extract.getType() != dim.getType())
854  extract =
855  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
856  rewriter.replaceOp(dim, extract);
857  return success();
858  }
859 };
860 } // namespace
861 
862 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
863  MLIRContext *context) {
864  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // EmptyOp
869 //===----------------------------------------------------------------------===//
870 
871 void EmptyOp::build(OpBuilder &builder, OperationState &result,
872  ArrayRef<int64_t> staticShape, Type elementType,
873  Attribute encoding) {
874  assert(all_of(staticShape,
875  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
876  "expected only static sizes");
877  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
878 }
879 
880 void EmptyOp::build(OpBuilder &builder, OperationState &result,
881  ArrayRef<int64_t> staticShape, Type elementType,
882  ValueRange dynamicSizes, Attribute encoding) {
883  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
884  build(builder, result, tensorType, dynamicSizes);
885 }
886 
887 void EmptyOp::build(OpBuilder &builder, OperationState &result,
888  ArrayRef<OpFoldResult> sizes, Type elementType,
889  Attribute encoding) {
890  SmallVector<int64_t> staticShape;
891  SmallVector<Value> dynamicSizes;
892  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
893  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
894 }
895 
896 LogicalResult EmptyOp::verify() {
897  if (getType().getNumDynamicDims() !=
898  static_cast<int64_t>(getDynamicSizes().size()))
899  return emitOpError("incorrect number of dynamic sizes, has ")
900  << getDynamicSizes().size() << ", expected "
901  << getType().getNumDynamicDims();
902  return success();
903 }
904 
905 LogicalResult
906 EmptyOp::reifyResultShapes(OpBuilder &builder,
907  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
908  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
909  unsigned ctr = 0;
910  for (int64_t i = 0; i < getType().getRank(); ++i) {
911  if (getType().isDynamicDim(i)) {
912  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
913  } else {
914  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
915  }
916  }
917  return success();
918 }
919 
920 Value EmptyOp::getDynamicSize(unsigned idx) {
921  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
922  unsigned ctr = 0;
923  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
924  if (getType().isDynamicDim(i))
925  ++ctr;
926  return getDynamicSizes()[ctr];
927 }
928 
929 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
930  SmallVector<OpFoldResult> result;
931  unsigned ctr = 0;
932  OpBuilder b(getContext());
933  for (int64_t i = 0; i < getType().getRank(); ++i) {
934  if (getType().isDynamicDim(i)) {
935  result.push_back(getDynamicSizes()[ctr++]);
936  } else {
937  result.push_back(b.getIndexAttr(getType().getShape()[i]));
938  }
939  }
940  return result;
941 }
942 
943 namespace {
944 /// Change the type of the result of a `tensor.empty` by making the result
945 /// type statically sized along dimensions that in the original operation were
946 /// defined as dynamic, but the size was defined using a `constant` op. For
947 /// example
948 ///
949 /// %c5 = arith.constant 5: index
950 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
951 ///
952 /// to
953 ///
954 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
955 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
956  using OpRewritePattern<EmptyOp>::OpRewritePattern;
957 
958  LogicalResult matchAndRewrite(EmptyOp op,
959  PatternRewriter &rewriter) const override {
960  SmallVector<Value> foldedDynamicSizes;
961  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
962  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
963 
964  // Stop here if no dynamic size was promoted to static.
965  if (foldedTensorType == op.getType())
966  return failure();
967 
968  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
969  foldedDynamicSizes);
970  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
971  return success();
972  }
973 };
974 
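/// Fold `tensor.dim` of a `tensor.empty` with a dynamic dimension to the
/// corresponding dynamic size operand. For example (illustrative):
///   %0 = tensor.empty(%sz) : tensor<?x8xf32>
///   %d = tensor.dim %0, %c0 : tensor<?x8xf32>
/// folds to %sz.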
975 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
976  using OpRewritePattern<DimOp>::OpRewritePattern;
977 
978  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
979  PatternRewriter &rewriter) const override {
980  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
981  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
982  if (!emptyTensorOp || !maybeConstantIndex)
983  return failure();
984  if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
985  return failure();
986  rewriter.replaceOp(dimOp,
987  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
988  return success();
989  }
990 };
991 
992 /// Canonicalize
993 ///
994 /// ```mlir
995 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
996 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
997 /// ```
998 ///
999 /// into
1000 ///
1001 /// ```mlir
1002 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1003 /// ```
1004 ///
1005 /// This assumes the input program is correct in terms of its shape. So it is
1006 /// safe to assume that `%d0` is in fact 4.
1007 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1008  using OpRewritePattern<CastOp>::OpRewritePattern;
1009 
1010  LogicalResult matchAndRewrite(CastOp castOp,
1011  PatternRewriter &rewriter) const override {
1012  if (!canFoldIntoProducerOp(castOp))
1013  return failure();
1014  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1015  if (!producer)
1016  return failure();
1017 
1018  auto resultType =
1019  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1020  ArrayRef<int64_t> resultShape = resultType.getShape();
1021  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1022  SmallVector<OpFoldResult> newMixedSizes;
1023  newMixedSizes.reserve(currMixedSizes.size());
1024  assert(resultShape.size() == currMixedSizes.size() &&
1025  "mismatch in result shape and sizes of empty op");
1026  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1027  int64_t newDim = std::get<0>(it);
1028  OpFoldResult currDim = std::get<1>(it);
1029  // Case 1: The empty tensor dim is static. Check that the tensor cast
1030  // result dim matches.
1031  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1032  if (ShapedType::isDynamic(newDim) ||
1033  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1034  // Something is off, the cast result shape cannot be more dynamic
1035  // than the empty tensor result shape (enforced by
1036  // `canFoldIntoProducer`). Abort for now.
1037  return rewriter.notifyMatchFailure(
1038  producer, "mismatch in static value of shape of empty tensor "
1039  "result and cast result");
1040  }
1041  newMixedSizes.push_back(attr);
1042  continue;
1043  }
1044 
1045  // Case 2 : The tensor cast shape is static, but empty tensor result
1046  // shape is dynamic.
1047  if (!ShapedType::isDynamic(newDim)) {
1048  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1049  continue;
1050  }
1051 
1052  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1053  // shape is dynamic. Use the dynamic value from the empty tensor op.
1054  newMixedSizes.push_back(currDim);
1055  }
1056 
1057  // TODO: Do not drop tensor encoding.
1058  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1059  resultType.getElementType());
1060  return success();
1061  }
1062 };
1063 
1064 } // namespace
1065 
1066 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1067  MLIRContext *context) {
1068  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1069  ReplaceEmptyTensorStaticShapeDims>(context);
1070 }
1071 
1072 /// Try to remove a tensor operation if it would only reshape a constant.
1073 /// Removes the op and replaces the constant with a new constant of the result
1074 /// shape. When an optional cst attribute is passed, it is reshaped only if the
1075 /// splat value matches the value in the attribute.
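/// For instance (illustrative): a splat constant of type tensor<2x3xf32>
/// reshaped to a static tensor<6xf32> folds to the same splat value with the
/// new type; non-splat constants and dynamic result shapes are left unchanged.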
1076 static OpFoldResult
1077 reshapeConstantSource(DenseElementsAttr source, TensorType result,
1078  std::optional<Attribute> cst = std::nullopt) {
1079  if (source && source.isSplat() && result.hasStaticShape() &&
1080  (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
1081  return source.resizeSplat(result);
1082 
1083  return {};
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // ExtractOp
1088 //===----------------------------------------------------------------------===//
1089 
1090 namespace {
1091 
1092 /// Canonicalizes the pattern of the form
1093 ///
1094 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1095 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1096 ///
1097 /// to
1098 ///
1099 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1100 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1101  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1102 
1103  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1104  PatternRewriter &rewriter) const final {
1105  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1106  if (!tensorCast)
1107  return failure();
1108  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1109  return failure();
1110  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1111  extract, tensorCast.getSource(), extract.getIndices());
1112  return success();
1113  }
1114 };
1115 
1116 } // namespace
1117 
1118 void ExtractOp::getAsmResultNames(
1119  function_ref<void(Value, StringRef)> setNameFn) {
1120  setNameFn(getResult(), "extracted");
1121 }
1122 
1123 LogicalResult ExtractOp::verify() {
1124  // Verify the # indices match if we have a ranked type.
1125  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1126  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1127  return emitOpError("incorrect number of indices for extract_element");
1128  return success();
1129 }
1130 
1131 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1132  // If this is a splat elements attribute, simply return the value. All of
1133  // the elements of a splat attribute are the same.
1134  if (Attribute tensor = adaptor.getTensor())
1135  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1136  return splatTensor.getSplatValue<Attribute>();
1137 
1138  // Collect the constant indices into the tensor.
1139  SmallVector<uint64_t, 8> indices;
1140  for (Attribute indice : adaptor.getIndices()) {
1141  if (!indice || !llvm::isa<IntegerAttr>(indice))
1142  return {};
1143  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1144  }
1145 
1146  // Fold extract(from_elements(...)).
1147  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1148  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1149  auto rank = tensorType.getRank();
1150  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1151  "rank mismatch");
1152  int flatIndex = 0;
1153  int stride = 1;
1154  for (int i = rank - 1; i >= 0; --i) {
1155  flatIndex += indices[i] * stride;
1156  stride *= tensorType.getDimSize(i);
1157  }
1158  // Prevent out of bounds accesses. This can happen in invalid code that
1159  // will never execute.
1160  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1161  flatIndex < 0)
1162  return {};
1163  return fromElementsOp.getElements()[flatIndex];
1164  }
1165 
1166  // If this is an elements attribute, query the value at the given indices.
1167  if (Attribute tensor = adaptor.getTensor()) {
1168  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1169  if (elementsAttr && elementsAttr.isValidIndex(indices))
1170  return elementsAttr.getValues<Attribute>()[indices];
1171  }
1172 
1173  return {};
1174 }
1175 
1176 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1177  MLIRContext *context) {
1178  results.add<ExtractFromTensorCast>(context);
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // FromElementsOp
1183 //===----------------------------------------------------------------------===//
1184 
1185 void FromElementsOp::getAsmResultNames(
1186  function_ref<void(Value, StringRef)> setNameFn) {
1187  setNameFn(getResult(), "from_elements");
1188 }
1189 
1190 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1191  ValueRange elements) {
1192  assert(!elements.empty() && "expected at least one element");
1193  Type resultType = RankedTensorType::get(
1194  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1195  build(builder, result, resultType, elements);
1196 }
1197 
1198 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1199  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1200  return DenseElementsAttr::get(getType(), adaptor.getElements());
1201  return {};
1202 }
1203 
1204 namespace {
1205 
1206 // Pushes the index_casts that occur before extractions to after the extract.
1207 // This minimizes type conversion in some cases and enables the extract
1208 // canonicalizer. This changes:
1209 //
1210 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1211 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1212 //
1213 // to the following:
1214 //
1215 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1216 // %cast = arith.index_cast %extract : i32 to index
1217 //
1218 // which further extract canonicalizations can then fold down to just the element.
1219 //
1220 // Consider expanding this to a template and handle all tensor cast
1221 // operations.
1222 struct ExtractElementFromIndexCast
1223  : public OpRewritePattern<tensor::ExtractOp> {
1224  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1225 
1226  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1227  PatternRewriter &rewriter) const final {
1228  Location loc = extract.getLoc();
1229  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1230  if (!indexCast)
1231  return failure();
1232 
1233  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1234 
1235  auto newExtract = rewriter.create<tensor::ExtractOp>(
1236  loc, elementTy, indexCast.getIn(), extract.getIndices());
1237 
1238  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1239  newExtract);
1240 
1241  return success();
1242  }
1243 };
1244 
1245 } // namespace
1246 
1247 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1248  MLIRContext *context) {
1249  results.add<ExtractElementFromIndexCast>(context);
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // GatherOp
1254 //===----------------------------------------------------------------------===//
1255 
1256 void GatherOp::getAsmResultNames(
1257  function_ref<void(Value, StringRef)> setNameFn) {
1258  setNameFn(getResult(), "gather");
1259 }
1260 
1261 /// Return the inferred result type for a gatherOp where:
1262 /// - sourceType is the type of the source tensor gathered from
1263 /// - indicesType is the type of the indices used to gather
1264 /// - gatherDims are the dims along which the gather occurs.
1265 /// Return a full rank or ranked-reduced variant of the type depending on
1266 /// the value of rankReduced.
1267 ///
1268 /// The leading dimensions of the index tensor give the result tensor its
1269 /// leading dimensions.
1270 /// The trailing dimensions of the result tensor are obtained from the source
1271 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1272 /// rankedReduced is false), or skipping them (otherwise).
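/// For example (illustrative): gathering along dim 1 of a tensor<4x5x6xf32>
/// source with indices of type tensor<10x1xindex> yields tensor<10x4x1x6xf32>
/// (rankReduced = false) or tensor<10x4x6xf32> (rankReduced = true).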
1273 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1274  RankedTensorType indicesType,
1275  ArrayRef<int64_t> gatherDims,
1276  bool rankReduced) {
1277  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1278  resultShape.reserve(resultShape.size() + sourceType.getRank());
1279  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1280  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1281  if (!rankReduced)
1282  resultShape.push_back(1);
1283  continue;
1284  }
1285  resultShape.push_back(sourceType.getDimSize(idx));
1286  }
1287  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1288 }
1289 
1290 static LogicalResult
1291 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1292  ArrayRef<int64_t> indices, int64_t rank,
1293  StringRef gatherOrScatter, StringRef sourceOrDest) {
1294  if (dims.empty())
1295  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1296 
1297  int64_t numGatherDims = dims.size();
1298  if (numGatherDims > rank)
1299  return op->emitOpError(gatherOrScatter)
1300  << "_dims overflow " << sourceOrDest << " rank";
1301  if (indices.empty() || indices.back() != numGatherDims)
1302  return op->emitOpError(gatherOrScatter)
1303  << "_dims length must match the size of last dimension of indices";
1304  for (int64_t val : dims) {
1305  if (val < 0)
1306  return op->emitOpError(gatherOrScatter)
1307  << "_dims value must be non-negative";
1308  if (val >= rank)
1309  return op->emitOpError(gatherOrScatter)
1310  << "_dims value must be smaller than " << sourceOrDest << " rank";
1311  }
1312  for (int64_t i = 1; i < numGatherDims; ++i) {
1313  if (dims[i - 1] >= dims[i])
1314  return op->emitOpError(gatherOrScatter)
1315  << "_dims values must be strictly increasing";
1316  }
1317  return success();
1318 }
1319 
1320 LogicalResult GatherOp::verify() {
1321  int64_t sourceRank = getSourceType().getRank();
1322  ArrayRef<int64_t> gatherDims = getGatherDims();
1323  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1324  getIndicesType().getShape(), sourceRank,
1325  "gather", "source")))
1326  return failure();
1327 
1328  RankedTensorType expectedResultType = GatherOp::inferResultType(
1329  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1330  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1331  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1332  if (getResultType() != expectedResultType &&
1333  getResultType() != expectedRankReducedResultType) {
1334  return emitOpError("result type "
1335  "mismatch: "
1336  "expected ")
1337  << expectedResultType << " or its rank-reduced variant "
1338  << expectedRankReducedResultType << " (got: " << getResultType()
1339  << ")";
1340  }
1341 
1342  return success();
1343 }
1344 
1345 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1346  if (OpFoldResult reshapedSource = reshapeConstantSource(
1347  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1348  getResult().getType()))
1349  return reshapedSource;
1350  return {};
1351 }
1352 
1353 //===----------------------------------------------------------------------===//
1354 // InsertOp
1355 //===----------------------------------------------------------------------===//
1356 
1357 void InsertOp::getAsmResultNames(
1358  function_ref<void(Value, StringRef)> setNameFn) {
1359  setNameFn(getResult(), "inserted");
1360 }
1361 
1362 LogicalResult InsertOp::verify() {
1363  // Verify the # indices match if we have a ranked type.
1364  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1365  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1366  return emitOpError("incorrect number of indices");
1367  return success();
1368 }
1369 
1370 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1371  Attribute scalar = adaptor.getScalar();
1372  Attribute dest = adaptor.getDest();
1373  if (scalar && dest)
1374  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1375  if (scalar == splatDest.getSplatValue<Attribute>())
1376  return dest;
1377  return {};
1378 }
1379 
1380 //===----------------------------------------------------------------------===//
1381 // GenerateOp
1382 //===----------------------------------------------------------------------===//
1383 
1384 void GenerateOp::getAsmResultNames(
1385  function_ref<void(Value, StringRef)> setNameFn) {
1386  setNameFn(getResult(), "generated");
1387 }
1388 
1389 LogicalResult GenerateOp::reifyResultShapes(
1390  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1391  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1392  int idx = 0;
1393  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1394  if (getType().isDynamicDim(dim)) {
1395  reifiedReturnShapes[0][dim] = getOperand(idx++);
1396  } else {
1397  reifiedReturnShapes[0][dim] =
1398  builder.getIndexAttr(getType().getDimSize(dim));
1399  }
1400  }
1401  return success();
1402 }
1403 
1404 LogicalResult GenerateOp::verify() {
1405  // Ensure that the tensor type has as many dynamic dimensions as are
1406  // specified by the operands.
1407  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1408  if (getNumOperands() != resultType.getNumDynamicDims())
1409  return emitError("must have as many index operands as dynamic extents "
1410  "in the result type");
1411  return success();
1412 }
1413 
1414 LogicalResult GenerateOp::verifyRegions() {
1415  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1416  // Ensure that region arguments span the index space.
1417  if (!llvm::all_of(getBody().getArgumentTypes(),
1418  [](Type ty) { return ty.isIndex(); }))
1419  return emitError("all body arguments must be index");
1420  if (getBody().getNumArguments() != resultTy.getRank())
1421  return emitError("must have one body argument per input dimension");
1422 
1423  // Ensure that the region yields an element of the right type.
1424  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1425 
1426  if (yieldOp.getValue().getType() != resultTy.getElementType())
1427  return emitOpError(
1428  "body must be terminated with a `yield` operation of the tensor "
1429  "element type");
1430 
1431  return success();
1432 }
1433 
1434 void GenerateOp::build(
1435  OpBuilder &b, OperationState &result, Type resultTy,
1436  ValueRange dynamicExtents,
1437  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1438  build(b, result, resultTy, dynamicExtents);
1439 
1440  // Build and populate body.
1441  OpBuilder::InsertionGuard guard(b);
1442  Region *bodyRegion = result.regions.front().get();
1443  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1444  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1445  SmallVector<Location, 2> argumentLocs(rank, result.location);
1446  Block *bodyBlock =
1447  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1448  bodyBuilder(b, result.location, bodyBlock->getArguments());
1449 }
1450 
1451 namespace {
1452 
1453 /// Canonicalizes tensor.generate operations with a constant
1454 /// operand into the equivalent operation with the operand expressed in the
1455 /// result type, instead. We also insert a type cast to make sure that the
1456 /// resulting IR is still well-typed.
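/// For example (illustrative):
///   %c5 = arith.constant 5 : index
///   %0 = tensor.generate %n, %c5 { ... } : tensor<?x?xindex>
/// becomes a tensor.generate of type tensor<?x5xindex> (with only %n as an
/// operand), followed by a tensor.cast back to tensor<?x?xindex>.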
1457 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1458  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1459 
1460  LogicalResult matchAndRewrite(GenerateOp generateOp,
1461  PatternRewriter &rewriter) const final {
1462  SmallVector<Value> foldedDynamicSizes;
1463  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1464  generateOp.getType(), generateOp.getDynamicExtents(),
1465  foldedDynamicSizes);
1466 
1467  // Stop here if no dynamic size was promoted to static.
1468  if (foldedTensorType == generateOp.getType())
1469  return failure();
1470 
1471  auto loc = generateOp.getLoc();
1472  auto newOp =
1473  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1474  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1475  newOp.getBody().begin());
1476  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1477  generateOp.getType(), newOp);
1478  return success();
1479  }
1480 };
1481 
1482 /// Canonicalizes the pattern of the form
1483 ///
1484 /// %tensor = tensor.generate %x {
1485 /// ^bb0(%arg0: index):
1486 /// <computation>
1487 /// yield %1 : index
1488 /// } : tensor<?xindex>
1489 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1490 ///
1491 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1492 /// tensor.generate operation has no side-effects.
1493 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1494  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1495 
1496  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1497  PatternRewriter &rewriter) const final {
1498  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1499  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1500  return failure();
1501 
1502  IRMapping mapping;
1503  Block *body = &tensorFromElements.getBody().front();
1504  mapping.map(body->getArguments(), extract.getIndices());
1505  for (auto &op : body->without_terminator())
1506  rewriter.clone(op, mapping);
1507 
1508  auto yield = cast<YieldOp>(body->getTerminator());
1509 
1510  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1511  return success();
1512  }
1513 };
1514 
1515 } // namespace
1516 
1517 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1518  MLIRContext *context) {
1519  // TODO: Move extract pattern to tensor::ExtractOp.
1520  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1521 }
1522 
1523 //===----------------------------------------------------------------------===//
1524 // RankOp
1525 //===----------------------------------------------------------------------===//
1526 
1527 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1528  setNameFn(getResult(), "rank");
1529 }
1530 
1531 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1532  // Constant fold rank when the rank of the operand is known.
1533  auto type = getOperand().getType();
1534  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1535  if (shapedType && shapedType.hasRank())
1536  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1537  return IntegerAttr();
1538 }
1539 
1540 //===----------------------------------------------------------------------===//
1541 // ReshapeOp
1542 //===----------------------------------------------------------------------===//
1543 
1544 void ReshapeOp::getAsmResultNames(
1545  function_ref<void(Value, StringRef)> setNameFn) {
1546  setNameFn(getResult(), "reshape");
1547 }
1548 
1549 static int64_t getNumElements(ShapedType type) {
1550  int64_t numElements = 1;
1551  for (auto dim : type.getShape())
1552  numElements *= dim;
1553  return numElements;
1554 }
1555 
1556 LogicalResult ReshapeOp::verify() {
1557  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1558  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1559 
1560  if (operandType.getElementType() != resultType.getElementType())
1561  return emitOpError("element types of source and destination tensor "
1562  "types should be the same");
1563 
1564  int64_t shapeSize =
1565  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1566  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1567  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1568 
1569  if (resultRankedType) {
1570  if (operandRankedType && resultRankedType.hasStaticShape() &&
1571  operandRankedType.hasStaticShape()) {
1572  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1573  return emitOpError("source and destination tensor should have the "
1574  "same number of elements");
1575  }
1576  if (ShapedType::isDynamic(shapeSize))
1577  return emitOpError("cannot use shape operand with dynamic length to "
1578  "reshape to statically-ranked tensor type");
1579  if (shapeSize != resultRankedType.getRank())
1580  return emitOpError(
1581  "length of shape operand differs from the result's tensor rank");
1582  }
1583  return success();
1584 }
1585 
1586 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1587  if (OpFoldResult reshapedSource = reshapeConstantSource(
1588  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1589  getResult().getType()))
1590  return reshapedSource;
1591 
1592   // If the producer of operand 'source' is another 'tensor.reshape' op, use
1593   // the producer's input instead as the original tensor to reshape. This may
1594   // render the producer dead.
1595  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1596  getSourceMutable().assign(reshapeOpProducer.getSource());
1597  return getResult();
1598  }
1599 
1600  auto source = getSource();
1601  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1602  auto resultTy = dyn_cast<RankedTensorType>(getType());
1603  if (!sourceTy || !resultTy || sourceTy != resultTy)
1604  return {};
1605 
1606  // If the source and result are both 1D tensors and have the same type, the
1607  // reshape has no effect, even if the tensor is dynamically shaped.
1608  if (sourceTy.getRank() == 1)
1609  return source;
1610 
1611  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1612  auto elements = fromElements.getElements();
1613  bool dynamicNoop =
1614  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1615  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1616  auto element = elements[id];
1617 
1618  if (auto cst = getConstantIntValue(element)) {
1619  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1620  continue;
1621  }
1622 
1623  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1624  dynamicNoop &= dimOp.getSource() == source;
1625 
1626  APSInt dim;
1627  auto cst = getConstantIntValue(dimOp.getIndex());
1628  dynamicNoop &=
1629  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1630  continue;
1631  }
1632 
1633  dynamicNoop = false;
1634  break;
1635  }
1636 
1637  if (dynamicNoop)
1638  return source;
1639  }
1640 
1641  return {};
1642 }
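// A sketch (illustrative, not taken from an existing test) of the dynamic
// no-op case handled above: the reshape folds to %src because every element of
// the shape operand is either a matching constant or a tensor.dim of %src with
// the matching index.
//
//   %d0 = tensor.dim %src, %c0 : tensor<?x8xf32>
//   %c8 = arith.constant 8 : index
//   %shape = tensor.from_elements %d0, %c8 : tensor<2xindex>
//   %r = tensor.reshape %src(%shape)
//       : (tensor<?x8xf32>, tensor<2xindex>) -> tensor<?x8xf32>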
1643 
1644 //===----------------------------------------------------------------------===//
1645 // Reassociative reshape ops
1646 //===----------------------------------------------------------------------===//
1647 
1648 void CollapseShapeOp::getAsmResultNames(
1649  function_ref<void(Value, StringRef)> setNameFn) {
1650  setNameFn(getResult(), "collapsed");
1651 }
1652 
1653 void ExpandShapeOp::getAsmResultNames(
1654  function_ref<void(Value, StringRef)> setNameFn) {
1655  setNameFn(getResult(), "expanded");
1656 }
1657 
1658 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1659  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1660  "invalid resultDim");
1661  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1662  if (llvm::is_contained(it.value(), resultDim))
1663  return it.index();
1664  llvm_unreachable("could not find reassociation group");
1665 }
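// For example, with reassociation [[0, 1], [2]] (expanding a 2-D source into a
// 3-D result), result dims 0 and 1 both map to source dim 0, and result dim 2
// maps to source dim 1.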
1666 
1667 FailureOr<SmallVector<OpFoldResult>>
1668 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1669  RankedTensorType expandedType,
1670  ArrayRef<ReassociationIndices> reassociation,
1671  ArrayRef<OpFoldResult> inputShape) {
1672  std::optional<SmallVector<OpFoldResult>> outputShape =
1673  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1674  inputShape);
1675  if (!outputShape)
1676  return failure();
1677  return *outputShape;
1678 }
1679 
1680 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1681  Type resultType, Value src,
1682  ArrayRef<ReassociationIndices> reassociation,
1683  ArrayRef<OpFoldResult> outputShape) {
1684  auto [staticOutputShape, dynamicOutputShape] =
1685       decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1686   build(builder, result, cast<RankedTensorType>(resultType), src,
1687  getReassociationIndicesAttribute(builder, reassociation),
1688  dynamicOutputShape, staticOutputShape);
1689 }
1690 
1691 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1692  Type resultType, Value src,
1693  ArrayRef<ReassociationIndices> reassociation) {
1694  SmallVector<OpFoldResult> inputShape =
1695  getMixedSizes(builder, result.location, src);
1696  auto tensorResultTy = cast<RankedTensorType>(resultType);
1697  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1698  builder, result.location, tensorResultTy, reassociation, inputShape);
1699  SmallVector<OpFoldResult> outputShapeOrEmpty;
1700  if (succeeded(outputShape)) {
1701  outputShapeOrEmpty = *outputShape;
1702  }
1703  build(builder, result, tensorResultTy, src, reassociation,
1704  outputShapeOrEmpty);
1705 }
1706 
1707 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1708  return getSymbolLessAffineMaps(getReassociationExprs());
1709 }
1710 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1711   return convertReassociationIndicesToExprs(getContext(),
1712                                             getReassociationIndices());
1713 }
1714 
1715 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1716  return getSymbolLessAffineMaps(getReassociationExprs());
1717 }
1718 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1719   return convertReassociationIndicesToExprs(getContext(),
1720                                             getReassociationIndices());
1721 }
1722 
1723 RankedTensorType CollapseShapeOp::inferCollapsedType(
1724  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1725  return inferCollapsedType(
1726       type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1727                 type.getContext(), reassociation)));
1728 }
1729 
1730 /// Compute the RankedTensorType obtained by applying `reassociation` to
1731 /// `type`.
1732 RankedTensorType
1733 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1734  ArrayRef<AffineMap> reassociation) {
1735  auto shape = type.getShape();
1736  SmallVector<int64_t, 4> newShape;
1737  newShape.reserve(reassociation.size());
1738 
1739  // Use the fact that reassociation is valid to simplify the logic: only use
1740  // each map's rank.
1741  assert(isReassociationValid(reassociation) && "invalid reassociation");
1742  unsigned currentDim = 0;
1743  for (AffineMap m : reassociation) {
1744  unsigned dim = m.getNumResults();
1745  auto band = shape.slice(currentDim, dim);
1746  int64_t size = 1;
1747  if (llvm::is_contained(band, ShapedType::kDynamic))
1748  size = ShapedType::kDynamic;
1749  else
1750  for (unsigned d = 0; d < dim; ++d)
1751  size *= shape[currentDim + d];
1752  newShape.push_back(size);
1753  currentDim += dim;
1754  }
1755 
1756  return RankedTensorType::get(newShape, type.getElementType());
1757 }
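// For example, collapsing tensor<2x3x?xf32> with reassociation [[0, 1], [2]]
// yields tensor<6x?xf32>: the first group is fully static (2 * 3 = 6) while
// the second group stays dynamic because it contains a dynamic dimension.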
1758 
1759 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1760  ArrayRef<ReassociationIndices> reassociation,
1761  ArrayRef<NamedAttribute> attrs) {
1762  auto resultType = inferCollapsedType(
1763  llvm::cast<RankedTensorType>(src.getType()),
1764       getSymbolLessAffineMaps(
1765           convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1766  result.addAttribute(getReassociationAttrStrName(),
1767  getReassociationIndicesAttribute(b, reassociation));
1768  build(b, result, resultType, src, attrs);
1769 }
1770 
1771 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1772  TensorReshapeOp, ExpandShapeOp>::value>
1773 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1774  RankedTensorType expandedType,
1775  RankedTensorType collapsedType) {
1776  if (failed(
1777  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1778  return failure();
1779 
1780  auto maps = op.getReassociationMaps();
1781  RankedTensorType expectedType =
1782  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1783  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1784  return op.emitOpError("expected collapsed type to be ")
1785  << expectedType << ", but got " << collapsedType;
1786  return success();
1787 }
1788 
1789 LogicalResult ExpandShapeOp::verify() {
1790  auto srcType = getSrcType();
1791  auto resultType = getResultType();
1792 
1793  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1794  return emitOpError("expected number of static shape dims to be equal to "
1795  "the output rank (")
1796  << resultType.getRank() << ") but found "
1797  << getStaticOutputShape().size() << " inputs instead";
1798 
1799  if ((int64_t)getOutputShape().size() !=
1800  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1801  return emitOpError("mismatch in dynamic dims in output_shape and "
1802  "static_output_shape: static_output_shape has ")
1803  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1804  << " dynamic dims while output_shape has " << getOutputShape().size()
1805  << " values";
1806 
1807  return verifyTensorReshapeOp(*this, resultType, srcType);
1808 }
1809 
1810 LogicalResult CollapseShapeOp::verify() {
1811  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1812 }
1813 
1814 namespace {
1815 /// Reshape of a splat constant can be replaced with a constant of the result
1816 /// type.
1817 template <typename TensorReshapeOp>
1818 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1819   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1820   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1821  PatternRewriter &rewriter) const override {
1822  DenseElementsAttr attr;
1823  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1824  return failure();
1825  if (!attr || !attr.isSplat())
1826  return failure();
1827     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1828         reshapeOp.getResultType(), attr.getRawData());
1829  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1830  return success();
1831  }
1832 };
1833 
1834 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1835 template <typename TensorReshapeOp>
1836 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1837 public:
1838   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1839 
1840  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1841  PatternRewriter &rewriter) const override {
1842  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1843  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1844  return failure();
1845 
1846  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1847  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1848  return success();
1849  }
1850 };
1851 
1852 /// Reshape of a FromElements can be replaced with a FromElements of the
1853 /// result type
1854 template <typename TensorReshapeOp>
1855 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1856   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1857   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1858  PatternRewriter &rewriter) const override {
1859  auto fromElements =
1860  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1861  if (!fromElements)
1862  return failure();
1863 
1864  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1865 
1866  if (!shapedTy.hasStaticShape())
1867  return failure();
1868 
1869  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1870  fromElements.getElements());
1871  return success();
1872  }
1873 };
1874 
1875 // Fold CastOp into CollapseShapeOp when adding static information.
1876 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1877   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1878 
1879  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1880  PatternRewriter &rewriter) const override {
1881  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1882  if (!tensor::canFoldIntoConsumerOp(castOp))
1883  return failure();
1884 
1885  RankedTensorType srcType =
1886  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1887  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1888  srcType, collapseShapeOp.getReassociationMaps());
1889 
1890  if (newResultType == collapseShapeOp.getResultType()) {
1891  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1892  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1893  });
1894  } else {
1895  auto newOp = rewriter.create<CollapseShapeOp>(
1896  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1897  collapseShapeOp.getReassociation());
1898  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1899  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1900  }
1901  return success();
1902  }
1903 };
1904 
1905 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1906   using OpRewritePattern<DimOp>::OpRewritePattern;
1907 
1908  LogicalResult matchAndRewrite(DimOp dimOp,
1909  PatternRewriter &rewriter) const override {
1910  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1911  if (!expandShapeOp)
1912  return failure();
1913 
1914  // Only constant dimension values are supported.
1915  std::optional<int64_t> dim = dimOp.getConstantIndex();
1916  if (!dim.has_value())
1917  return failure();
1918 
1919  // Skip static dims. These are folded to constant ops.
1920  RankedTensorType resultType = expandShapeOp.getResultType();
1921  if (!resultType.isDynamicDim(*dim))
1922  return failure();
1923 
1924  // Find reassociation group that contains this result dimension.
1925  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1926 
1927  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1928  // ExpandShapeOp would be ambiguous.)
1929  int64_t product = 1;
1930  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1931  for (int64_t d : grp) {
1932  if (d != dim) {
1933  assert(!resultType.isDynamicDim(d) && "expected static dim");
1934  product *= resultType.getDimSize(d);
1935  }
1936  }
1937 
1938  // result dim size = src dim size / (product(other dims in reassoc group))
1939  Value srcDimSz =
1940  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1941  AffineExpr expr;
1942  bindSymbols(dimOp.getContext(), expr);
1943  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1944  dimOp, expr.floorDiv(product), srcDimSz);
1945  return success();
1946  }
1947 };
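// Illustrative effect of FoldDimOfExpandShape (a sketch, not an existing test):
//
//   %e = tensor.expand_shape %t [[0, 1]] output_shape [%sz, 8]
//       : tensor<?xf32> into tensor<?x8xf32>
//   %d = tensor.dim %e, %c0 : tensor<?x8xf32>
//
// %d is rewritten to an affine.apply that computes tensor.dim(%t, 0) floordiv 8,
// because result dim 0 is the only dynamic dimension in its reassociation group.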
1948 
1949 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1950   using OpRewritePattern<DimOp>::OpRewritePattern;
1951 
1952  LogicalResult matchAndRewrite(DimOp dimOp,
1953  PatternRewriter &rewriter) const override {
1954  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1955  if (!collapseShapeOp)
1956  return failure();
1957 
1958  // Only constant dimension values are supported.
1959  std::optional<int64_t> dim = dimOp.getConstantIndex();
1960  if (!dim.has_value())
1961  return failure();
1962 
1963  // Skip static dims. These are folded to constant ops.
1964  RankedTensorType resultType = collapseShapeOp.getResultType();
1965  if (!resultType.isDynamicDim(*dim))
1966  return failure();
1967 
1968  // Get reassociation group of the result dimension.
1969  ReassociationIndices group =
1970  collapseShapeOp.getReassociationIndices()[*dim];
1971 
1972  // result dim size = product(dims in reassoc group)
1973  SmallVector<Value> srcDimSizes;
1974     SmallVector<AffineExpr> syms;
1975     AffineExpr product;
1976     for (const auto &it : llvm::enumerate(group)) {
1977  srcDimSizes.push_back(rewriter.create<DimOp>(
1978  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1979  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1980  product = product ? product * syms.back() : syms.back();
1981  }
1982  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
1983  srcDimSizes);
1984  return success();
1985  }
1986 };
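// Illustrative effect of FoldDimOfCollapseShape (a sketch, not an existing test):
//
//   %c = tensor.collapse_shape %t [[0, 1]] : tensor<?x6xf32> into tensor<?xf32>
//   %d = tensor.dim %c, %c0 : tensor<?xf32>
//
// %d is rewritten to an affine.apply computing s0 * s1 over tensor.dim(%t, 0)
// and tensor.dim(%t, 1).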
1987 } // namespace
1988 
1989 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1990  MLIRContext *context) {
1991  results.add<
1992       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
1993       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
1994       FoldReshapeWithConstant<ExpandShapeOp>,
1995  FoldReshapeWithSplat<ExpandShapeOp>,
1996  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1997  FoldDimOfCollapseShape>(context);
1998 }
1999 
2000 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2001  MLIRContext *context) {
2002  results.add<
2003       ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2004       ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2005  tensor::DimOp, RankedTensorType>,
2006  FoldReshapeWithConstant<CollapseShapeOp>,
2007  FoldReshapeWithSplat<CollapseShapeOp>,
2008  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2009  context);
2010 }
2011 
2012 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2013  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2014  adaptor.getOperands());
2015 }
2016 
2017 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2018  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2019  adaptor.getOperands());
2020 }
2021 
2022 //===----------------------------------------------------------------------===//
2023 // ExtractSliceOp
2024 //===----------------------------------------------------------------------===//
2025 
2026 void ExtractSliceOp::getAsmResultNames(
2027  function_ref<void(Value, StringRef)> setNameFn) {
2028  setNameFn(getResult(), "extracted_slice");
2029 }
2030 
2031 /// An extract_slice result type can be inferred, when it is not
2032 /// rank-reduced, from the source type and the static representation of
2033 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2034 RankedTensorType ExtractSliceOp::inferResultType(
2035  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2036  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2037   // The result shape is taken directly from the static sizes: their number
2038   // must match the rank of the source tensor type (checked below), and the
2039   // element type and encoding of the source are preserved.
2040  assert(static_cast<int64_t>(staticSizes.size()) ==
2041  sourceTensorType.getRank() &&
2042  "unexpected staticSizes not equal to rank of source");
2043  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2044  sourceTensorType.getEncoding());
2045 }
2046 
2047 RankedTensorType ExtractSliceOp::inferResultType(
2048  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2049     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2050   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2051  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2052  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2053  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2054  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2055  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2056  staticSizes, staticStrides);
2057 }
2058 
2059 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2060 /// number of sizes), drop as many size-1 dimensions as needed to produce an
2061 /// inferred type with the desired rank.
2062 ///
2063 /// Note that there may be multiple ways to compute this rank-reduced type:
2064 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2065 ///
2066 /// To disambiguate, this function always drops the earliest (leftmost) size-1 dimensions.
2067 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2068  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2069  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2070  ArrayRef<int64_t> strides) {
2071  // Type inferred in the absence of rank-reducing behavior.
2072  auto inferredType = llvm::cast<RankedTensorType>(
2073  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2074  int rankDiff = inferredType.getRank() - desiredResultRank;
2075  if (rankDiff > 0) {
2076  auto shape = inferredType.getShape();
2077  llvm::SmallBitVector dimsToProject =
2078  getPositionsOfShapeOne(rankDiff, shape);
2079  SmallVector<int64_t> projectedShape;
2080  // Best effort rank-reducing: drop 1s in order.
2081  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2082  if (!dimsToProject.test(pos))
2083  projectedShape.push_back(shape[pos]);
2084  inferredType =
2085  RankedTensorType::get(projectedShape, inferredType.getElementType());
2086  }
2087  return inferredType;
2088 }
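// For example, with desiredResultRank = 2 and an inferred (non-rank-reduced)
// type of tensor<1x6x1xf32>, only the first unit dimension is dropped and the
// canonical rank-reduced result type is tensor<6x1xf32>.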
2089 
2090 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2091  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2092     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2093     ArrayRef<OpFoldResult> strides) {
2094  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2095  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2096  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2097  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2098  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2099  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2100  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2101  staticStrides);
2102 }
2103 
2104 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2105 /// result type. If the type passed is nullptr, it is inferred.
2106 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2107  RankedTensorType resultType, Value source,
2108  ArrayRef<OpFoldResult> offsets,
2109  ArrayRef<OpFoldResult> sizes,
2110  ArrayRef<OpFoldResult> strides,
2111  ArrayRef<NamedAttribute> attrs) {
2112  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2113  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2114  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2115  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2116  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2117  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2118  // Structuring implementation this way avoids duplication between builders.
2119  if (!resultType) {
2120  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2121  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2122  }
2123  result.addAttributes(attrs);
2124  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2125  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2126  b.getDenseI64ArrayAttr(staticSizes),
2127  b.getDenseI64ArrayAttr(staticStrides));
2128 }
2129 
2130 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2131 /// result type.
2132 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2133  ArrayRef<OpFoldResult> offsets,
2134  ArrayRef<OpFoldResult> sizes,
2135  ArrayRef<OpFoldResult> strides,
2136  ArrayRef<NamedAttribute> attrs) {
2137  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2138 }
2139 
2140 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2141 /// a Range vector.
2142 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2143  ArrayRef<Range> ranges,
2144  ArrayRef<NamedAttribute> attrs) {
2145  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2146  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2147 }
2148 
2149 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2150 /// the type passed is nullptr, it is inferred.
2151 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2152  RankedTensorType resultType, Value source,
2153  ValueRange offsets, ValueRange sizes,
2154  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2155  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2156  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2157  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2158  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2159  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2160  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2161  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2162 }
2163 
2164 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2165 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2166  ValueRange offsets, ValueRange sizes,
2167  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2168  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2169 }
2170 
2171 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2172  Operation *op,
2173  RankedTensorType expectedType) {
2174  switch (result) {
2175   case SliceVerificationResult::Success:
2176     return success();
2177   case SliceVerificationResult::RankTooLarge:
2178     return op->emitError("expected rank to be smaller or equal to ")
2179            << "the other rank. ";
2180   case SliceVerificationResult::SizeMismatch:
2181     return op->emitError("expected type to be ")
2182            << expectedType << " or a rank-reduced version. (size mismatch) ";
2183   case SliceVerificationResult::ElemTypeMismatch:
2184     return op->emitError("expected element type to be ")
2185            << expectedType.getElementType();
2186  default:
2187  llvm_unreachable("unexpected extract_slice op verification result");
2188  }
2189 }
2190 
2191 /// Verifier for ExtractSliceOp.
2192 LogicalResult ExtractSliceOp::verify() {
2193  // Verify result type against inferred type.
2194  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2195  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2196  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2197  return produceSliceErrorMsg(result, *this, expectedType);
2198 }
2199 
2200 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2201   return ::getDroppedDims(getType().getShape(), getMixedSizes());
2202 }
2203 
2204 FailureOr<Value>
2205 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2206  ArrayRef<int64_t> desiredShape) {
2207  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2208  assert(sourceTensorType && "not a ranked tensor type");
2209  auto sourceShape = sourceTensorType.getShape();
2210  if (sourceShape.equals(desiredShape))
2211  return value;
2212  auto maybeRankReductionMask =
2213  mlir::computeRankReductionMask(sourceShape, desiredShape);
2214  if (!maybeRankReductionMask)
2215  return failure();
2216   return createCanonicalRankReducingExtractSliceOp(
2217       b, loc, value,
2218  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2219 }
2220 
2221 LogicalResult ExtractSliceOp::reifyResultShapes(
2222  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2223  reifiedReturnShapes.resize(1);
2224  reifiedReturnShapes[0].reserve(getType().getRank());
2225   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2226   llvm::SmallBitVector droppedDims = getDroppedDims();
2227  for (const auto &size : enumerate(mixedSizes)) {
2228  if (droppedDims.test(size.index()))
2229  continue;
2230  reifiedReturnShapes[0].push_back(size.value());
2231  }
2232  return success();
2233 }
2234 
2235 namespace {
2236 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2237 /// This essentially pushes the tensor.cast past its consuming slice when
2238 /// `canFoldIntoConsumerOp` is true.
2239 ///
2240 /// Example:
2241 /// ```
2242 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2243 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2244 /// tensor<3x4xf32>
2245 /// ```
2246 /// is rewritten into:
2247 /// ```
2248 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2249 /// tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2250 /// ```
2251 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2252 public:
2253   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2254 
2255  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2256  PatternRewriter &rewriter) const override {
2257   // If any operand is a constant index, bail and let the constant folder kick in first.
2258  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2259  return matchPattern(operand, matchConstantIndex());
2260  }))
2261  return failure();
2262 
2263  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2264  if (!castOp)
2265  return failure();
2266 
2267  if (!canFoldIntoConsumerOp(castOp))
2268  return failure();
2269 
2270  // Create folded extract.
2271  Location loc = sliceOp.getLoc();
2272  Value newResult = rewriter.create<ExtractSliceOp>(
2273  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2274  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2275  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2276  if (newResult.getType() != sliceOp.getType())
2277  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2278  rewriter.replaceOp(sliceOp, newResult);
2279  return success();
2280  }
2281 };
2282 
2283 /// Slice elements from `values` into `outValues`. For each dimension, `counts`
2284 /// holds the number of original elements spanned by one step in that dimension
2285 /// (the linearized stride). The output values can be used to construct a DenseElementsAttr.
2286 template <typename IterTy, typename ElemTy>
2287 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2288  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2289  ArrayRef<int64_t> strides,
2290  llvm::SmallVectorImpl<ElemTy> *outValues) {
2291  assert(offsets.size() == sizes.size());
2292  assert(offsets.size() == strides.size());
2293  if (offsets.empty())
2294  return;
2295 
2296  int64_t offset = offsets.front();
2297  int64_t size = sizes.front();
2298  int64_t stride = strides.front();
2299  if (offsets.size() == 1) {
2300  for (int64_t i = 0; i < size; ++i, offset += stride)
2301  outValues->push_back(*(values + offset));
2302 
2303  return;
2304  }
2305 
2306  for (int64_t i = 0; i < size; ++i, offset += stride) {
2307  auto begin = values + offset * counts.front();
2308  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2309  offsets.drop_front(), sizes.drop_front(),
2310  strides.drop_front(), outValues);
2311  }
2312 }
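// Worked example (for illustration): slicing a 2x4 constant
// [[0, 1, 2, 3], [4, 5, 6, 7]] with counts = [4, 1], offsets = [0, 1],
// sizes = [2, 2] and strides = [1, 2] appends 1, 3, 5, 7 to outValues,
// i.e. the 2x2 slice [[1, 3], [5, 7]].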
2313 
2314 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2315 /// folded operation might introduce more constant data; users can control
2316 /// whether the fold applies through the control function.
2317 class ConstantOpExtractSliceFolder final
2318  : public OpRewritePattern<ExtractSliceOp> {
2319 public:
2320   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2321 
2322  ConstantOpExtractSliceFolder(MLIRContext *context,
2323                                ControlConstantExtractSliceFusionFn controlFn)
2324       : OpRewritePattern<ExtractSliceOp>(context),
2325  controlFn(std::move(controlFn)) {}
2326 
2327  LogicalResult matchAndRewrite(ExtractSliceOp op,
2328  PatternRewriter &rewriter) const override {
2329  DenseElementsAttr attr;
2330  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2331  return failure();
2332 
2333  // A constant splat is handled by fold().
2334  if (attr.isSplat())
2335  return failure();
2336 
2337  // Dynamic result shape is not supported.
2338  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2339  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2340  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2341  return failure();
2342 
2343  // Customized control over the folding.
2344  if (!controlFn(op))
2345  return failure();
2346 
2347  int64_t count = sourceType.getNumElements();
2348  if (count == 0)
2349  return failure();
2350 
2351  // Check if there are any dynamic parts, which are not supported.
2352  auto offsets = op.getStaticOffsets();
2353  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2354  return failure();
2355  auto sizes = op.getStaticSizes();
2356  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2357  return failure();
2358  auto strides = op.getStaticStrides();
2359  if (llvm::is_contained(strides, ShapedType::kDynamic))
2360  return failure();
2361 
2362  // Compute the stride for each dimension.
2363  SmallVector<int64_t> counts;
2364  ArrayRef<int64_t> shape = sourceType.getShape();
2365  counts.reserve(shape.size());
2366  for (int64_t v : shape) {
2367  count = count / v;
2368  counts.push_back(count);
2369  }
2370 
2371  // New attribute constructed by the sliced values.
2372  DenseElementsAttr newAttr;
2373 
2374  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2375  SmallVector<APInt> outValues;
2376  outValues.reserve(sourceType.getNumElements());
2377  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2378  elems.begin(), counts, offsets, sizes, strides, &outValues);
2379  newAttr = DenseElementsAttr::get(resultType, outValues);
2380  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2381  SmallVector<APFloat> outValues;
2382  outValues.reserve(sourceType.getNumElements());
2383  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2384  elems.begin(), counts, offsets, sizes, strides, &outValues);
2385  newAttr = DenseElementsAttr::get(resultType, outValues);
2386  }
2387 
2388  if (newAttr) {
2389  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2390  return success();
2391  }
2392 
2393  return failure();
2394  }
2395 
2396 private:
2397   /// Controls whether the fold is applied; users can plug in their own
2398   /// heuristics through this function.
2399   ControlConstantExtractSliceFusionFn controlFn;
2400 };
2401 
2402 } // namespace
2403 
2404 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2405     RewritePatternSet &patterns,
2406  const ControlConstantExtractSliceFusionFn &controlFn) {
2407  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2408 }
2409 
2410 /// Return the canonical type of the result of an extract_slice op.
2411 struct SliceReturnTypeCanonicalizer {
2412   RankedTensorType operator()(ExtractSliceOp op,
2413  ArrayRef<OpFoldResult> mixedOffsets,
2414  ArrayRef<OpFoldResult> mixedSizes,
2415  ArrayRef<OpFoldResult> mixedStrides) {
2416  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2417  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2418  mixedStrides);
2419  }
2420 };
2421 
2422 /// A canonicalizer wrapper to replace ExtractSliceOps.
2423 struct SliceCanonicalizer {
2424   void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2425  ExtractSliceOp newOp) {
2426  Value replacement = newOp.getResult();
2427  if (replacement.getType() != op.getType())
2428  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2429  replacement);
2430  rewriter.replaceOp(op, replacement);
2431  }
2432 };
2433 
2434 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2435  MLIRContext *context) {
2436  results.add<
2437       OpWithOffsetSizesAndStridesConstantArgumentFolder<
2438           ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2439       ExtractSliceOpCastFolder>(context);
2440 }
2441 
2442 //
2443 static LogicalResult
2444 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2445  ShapedType shapedType) {
2446  OpBuilder b(op.getContext());
2447  for (OpFoldResult ofr : op.getMixedOffsets())
2448  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2449  return failure();
2450  // Rank-reducing noops only need to inspect the leading dimensions:
2451  // llvm::zip is appropriate.
2452  auto shape = shapedType.getShape();
2453  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2454  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2455  return failure();
2456  for (OpFoldResult ofr : op.getMixedStrides())
2457  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2458  return failure();
2459  return success();
2460 }
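// For example, an identity slice such as
//
//   %0 = tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
//       : tensor<4x8xf32> to tensor<4x8xf32>
//
// passes this check (all offsets are 0, the sizes match the shape, and all
// strides are 1) and is folded away by ExtractSliceOp::fold below.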
2461 
2462 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2463 /// slice, we can return the InsertSliceOp's source directly.
2464 // TODO: This only checks the immediate producer; extend to go up the
2465 // insert/extract chain if the slices are disjoint.
2466 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2467  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2468 
2469  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2470  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2471  insertOp.isSameAs(extractOp, isSame))
2472  return insertOp.getSource();
2473 
2474  return {};
2475 }
2476 
2477 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2478  if (OpFoldResult reshapedSource = reshapeConstantSource(
2479  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2480  getResult().getType()))
2481  return reshapedSource;
2482  if (getSourceType() == getType() &&
2483       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2484     return this->getSource();
2485  if (Value slice = foldExtractAfterInsertSlice(*this))
2486  return slice;
2487 
2488  return OpFoldResult();
2489 }
2490 
2491 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2492     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2493  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2494  unsigned rank = rankedTensorType.getRank();
2495  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2496  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2497  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2498  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2499  offsets, sizes, strides);
2500 }
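// For example (a sketch), with a tensor of type tensor<?x1x4xf32> and
// targetType = tensor<?x4xf32>, this creates a rank-reducing slice of the form
//
//   %0 = tensor.extract_slice %t[0, 0, 0] [%d0, 1, 4] [1, 1, 1]
//       : tensor<?x1x4xf32> to tensor<?x4xf32>
//
// where %d0 is the dynamic size of dimension 0 queried via tensor.dim.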
2501 
2502 //===----------------------------------------------------------------------===//
2503 // InsertSliceOp
2504 //===----------------------------------------------------------------------===//
2505 
2506 void InsertSliceOp::getAsmResultNames(
2507  function_ref<void(Value, StringRef)> setNameFn) {
2508  setNameFn(getResult(), "inserted_slice");
2509 }
2510 
2511 // Build an InsertSliceOp with mixed static and dynamic entries.
2512 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2513  Value dest, ArrayRef<OpFoldResult> offsets,
2514  ArrayRef<OpFoldResult> sizes,
2515  ArrayRef<OpFoldResult> strides,
2516  ArrayRef<NamedAttribute> attrs) {
2517  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2518  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2519  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2520  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2521  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2522  result.addAttributes(attrs);
2523  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2524  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2525  b.getDenseI64ArrayAttr(staticSizes),
2526  b.getDenseI64ArrayAttr(staticStrides));
2527 }
2528 
2529 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2530 /// Range vector.
2531 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2532  Value dest, ArrayRef<Range> ranges,
2533  ArrayRef<NamedAttribute> attrs) {
2534  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2535  build(b, result, source, dest, offsets, sizes, strides, attrs);
2536 }
2537 
2538 // Build an InsertSliceOp with dynamic entries.
2539 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2540  Value dest, ValueRange offsets, ValueRange sizes,
2541  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2542  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2543  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2544  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2545  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2546  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2547  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2548  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2549 }
2550 
2551 /// Rank-reducing type verification for both InsertSliceOp and
2552 /// ParallelInsertSliceOp.
2553 static SliceVerificationResult verifyInsertSliceOp(
2554     RankedTensorType srcType, RankedTensorType dstType,
2555  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2556  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2557  // insert_slice is the inverse of extract_slice, use the same type
2558  // inference.
2559  RankedTensorType expected = ExtractSliceOp::inferResultType(
2560  dstType, staticOffsets, staticSizes, staticStrides);
2561  if (expectedType)
2562  *expectedType = expected;
2563  return isRankReducedType(expected, srcType);
2564 }
2565 
2566 /// Verifier for InsertSliceOp.
2567 LogicalResult InsertSliceOp::verify() {
2568  RankedTensorType expectedType;
2569  SliceVerificationResult result =
2570  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2571  getStaticSizes(), getStaticStrides(), &expectedType);
2572  return produceSliceErrorMsg(result, *this, expectedType);
2573 }
2574 
2575 /// If we have two consecutive InsertSliceOps writing to the same slice, we
2576 /// can mutate the second InsertSliceOp's destination to the first one's.
2577 ///
2578 /// Example:
2579 ///
2580 /// ```mlir
2581 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2582 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2583 /// ```
2584 ///
2585 /// folds into:
2586 ///
2587 /// ```mlir
2588 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2589 /// ```
2590 ///
2591 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2592 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2593  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2594 
2595  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2596  if (!prevInsertOp ||
2597  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2598  !prevInsertOp.isSameAs(insertOp, isSame))
2599  return failure();
2600 
2601  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2602  return success();
2603 }
2604 
2605 /// Folds round-trip extract/insert slice op pairs.
2606 /// Example:
2607 /// ```mlir
2608 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2609 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2610 /// ```
2611 /// can be folded into %val.
2612 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2613  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2614 
2615  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2616  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2617  !extractOp.isSameAs(insertOp, isSame))
2618  return nullptr;
2619 
2620  return extractOp.getSource();
2621 }
2622 
2623 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2624  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2625  getSourceType() == getType() &&
2626       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2627     return this->getSource();
2628  if (succeeded(foldInsertAfterInsertSlice(*this)))
2629  return getResult();
2630  if (auto result = foldInsertAfterExtractSlice(*this))
2631  return result;
2632  if (llvm::any_of(getMixedSizes(),
2633  [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2634  return getDest();
2635  return OpFoldResult();
2636 }
2637 
2638 LogicalResult InsertSliceOp::reifyResultShapes(
2639  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2640  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2641  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2642  return success();
2643 }
2644 
2645 namespace {
2646 /// Pattern to rewrite an insert_slice op with constant arguments.
2647 ///
2648 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2649 template <typename InsertOpTy>
2650 class InsertSliceOpConstantArgumentFolder final
2651  : public OpRewritePattern<InsertOpTy> {
2652 public:
2653   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2654 
2655  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2656  PatternRewriter &rewriter) const override {
2657  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2658  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2659  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2660 
2661   // If no constant operands were folded, there is nothing to do.
2662  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2663  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2664  failed(foldDynamicStrideList(mixedStrides)))
2665  return failure();
2666 
2667  // Create the new op in canonical form.
2668  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2669  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2670  mixedOffsets, mixedSizes, mixedStrides);
2671  Value toInsert = insertSliceOp.getSource();
2672  if (sourceType != insertSliceOp.getSourceType()) {
2673  OpBuilder::InsertionGuard g(rewriter);
2674  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2675  // is that the insertion point is just before the ParallelCombiningOp in
2676  // the parallel case.
2677  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2678  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2679  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2680  sourceType, toInsert);
2681  }
2682  rewriter.replaceOpWithNewOp<InsertOpTy>(
2683  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2684  mixedSizes, mixedStrides);
2685  return success();
2686  }
2687 };
2688 
2689 /// Fold tensor_casts with insert_slice operations. If the source or
2690 /// destination tensor is a tensor_cast that removes static type information,
2691 /// the cast is folded into the insert_slice operation. E.g.:
2692 ///
2693 /// ```mlir
2694 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2695 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2696 /// ```
2697 ///
2698 /// folds into:
2699 ///
2700 /// ```mlir
2701 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2702 /// ```
2703 ///
2704 /// Note: When folding a cast on the destination tensor, the result of the
2705 /// insert_slice operation is cast back so that the type of the result does
2706 /// not change.
2707 ///
2708 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2709 template <typename InsertOpTy>
2710 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2711   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2712 
2713  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2714  PatternRewriter &rewriter) const override {
2715  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2716  return matchPattern(operand, matchConstantIndex());
2717  }))
2718  return failure();
2719 
2720  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2721  auto castOp = v.getDefiningOp<tensor::CastOp>();
2722  if (!castOp || !canFoldIntoConsumerOp(castOp))
2723  return std::nullopt;
2724  return castOp.getSource();
2725  };
2726  std::optional<Value> sourceCastSource =
2727  getSourceOfCastOp(insertSliceOp.getSource());
2728  std::optional<Value> destCastSource =
2729  getSourceOfCastOp(insertSliceOp.getDest());
2730  if (!sourceCastSource && !destCastSource)
2731  return failure();
2732 
2733  auto src =
2734  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2735  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2736  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2737  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2738  if (!srcType || !dstType)
2739  return failure();
2740 
2741  // The tensor.cast source could have additional static information not seen
2742  // in the insert slice op static sizes, so we ignore dynamic dims when
2743  // computing the rank reduction mask.
2744  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2745  auto rankReductionMask = computeRankReductionMask(
2746  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2747  if (!rankReductionMask.has_value())
2748  return failure();
2749  // Replace dimensions in the insert slice op with corresponding static dims
2750  // from the cast source type. If the insert slice sizes have static dims
2751  // that are not static in the tensor.cast source (i.e., when the cast op
2752  // casts a dynamic dim to static), the dim should not be replaced, and the
2753  // pattern will fail later in `verifyInsertSliceOp`.
2754  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2755  int64_t rankReducedIdx = 0;
2756  for (auto [idx, size] : enumerate(staticSizes)) {
2757  if (!rankReductionMask.value().contains(idx) &&
2758  !srcType.isDynamicDim(rankReducedIdx)) {
2759  mixedSizes[idx] = getAsIndexOpFoldResult(
2760  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2761  size = srcType.getDimSize(rankReducedIdx++);
2762  }
2763  }
2764  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2765  staticSizes, insertSliceOp.getStaticStrides()) !=
2766           SliceVerificationResult::Success)
2767     return failure();
2768 
2769  Operation *replacement = rewriter.create<InsertOpTy>(
2770  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2771  mixedSizes, insertSliceOp.getMixedStrides());
2772 
2773  // In the parallel case there is no result and so nothing to cast.
2774  bool isParallelInsert =
2775  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2776  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2777  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2778  insertSliceOp.getDestType(),
2779  replacement->getResult(0));
2780  }
2781  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2782  return success();
2783  }
2784 };
2785 
2786 /// If additional static type information can be deduced from an insert_slice's
2787 /// size operands, insert an explicit cast of the op's source operand. This
2788 /// enables other canonicalization patterns that are matching for tensor_cast
2789 /// ops such as `ForOpTensorCastFolder` in SCF.
2790 ///
2791 /// Example:
2792 ///
2793 /// ```mlir
2794 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2795 /// : tensor<?x?xf32> into ...
2796 /// ```
2797 ///
2798 /// folds into:
2799 ///
2800 /// ```mlir
2801 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2802 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2803 /// : tensor<64x64xf32> into ...
2804 /// ```
2805 ///
2806 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2807 template <typename InsertOpTy>
2808 struct InsertSliceOpSourceCastInserter final
2809  : public OpRewritePattern<InsertOpTy> {
2810   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2811 
2812  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2813  PatternRewriter &rewriter) const override {
2814  RankedTensorType srcType = insertSliceOp.getSourceType();
2815  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2816  return failure();
2817  SmallVector<int64_t> newSrcShape(srcType.getShape());
2818  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2819  if (std::optional<int64_t> constInt =
2820  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2821  // Bail on invalid IR.
2822  if (*constInt < 0)
2823  return failure();
2824  newSrcShape[i] = *constInt;
2825  }
2826  }
2827  if (!hasValidSizesOffsets(newSrcShape))
2828  return failure();
2829 
2830  RankedTensorType newSrcType = RankedTensorType::get(
2831  newSrcShape, srcType.getElementType(), srcType.getEncoding());
2832  if (srcType == newSrcType ||
2833  !preservesStaticInformation(srcType, newSrcType) ||
2834  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2835  return failure();
2836 
2837  // newSrcType is:
2838  // 1) Different from srcType.
2839  // 2) "More static" than srcType.
2840  // 3) Cast-compatible with srcType.
2841  // Insert the cast.
2842  OpBuilder::InsertionGuard g(rewriter);
2843  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2844  // that the insertion point is just before the ParallelCombiningOp in the
2845  // parallel case.
2846  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2847  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2848  Value cast = rewriter.create<tensor::CastOp>(
2849  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2850  rewriter.replaceOpWithNewOp<InsertOpTy>(
2851  insertSliceOp, cast, insertSliceOp.getDest(),
2852  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2853  insertSliceOp.getMixedStrides());
2854  return success();
2855  }
2856 };
2857 } // namespace
2858 
2859 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2860  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2861 }
2862 
2863 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2864  MLIRContext *context) {
2865  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2866  InsertSliceOpCastFolder<InsertSliceOp>,
2867  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2868 }
2869 
2870 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2871                                                              Location loc,
2872  Value tensor,
2873  Value dest) {
2874  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
2875  unsigned rank = rankedTensorType.getRank();
2876  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2877  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2878  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2879  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2880  sizes, strides);
2881 }
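// For example (a sketch), inserting a value of type tensor<4xf32> into a
// destination of type tensor<1x4xf32> creates
//
//   %0 = tensor.insert_slice %t into %dest[0, 0] [1, 4] [1, 1]
//       : tensor<4xf32> into tensor<1x4xf32>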
2882 
2883 //===----------------------------------------------------------------------===//
2884 // PadOp
2885 //===----------------------------------------------------------------------===//
2886 
2887 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2888  setNameFn(getResult(), "padded");
2889 }
2890 
2891 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2892 // supports optional types.
2893 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2894  Type typeToInfer, Type typeToInferFrom) {}
2895 
2896 ParseResult
2897 parseInferType(OpAsmParser &parser,
2898                std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2899  Type &typeToInfer, Type typeToInferFrom) {
2900  if (optOperand)
2901  typeToInfer = typeToInferFrom;
2902  return success();
2903 }
2904 
2905 LogicalResult PadOp::verify() {
2906  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
2907  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
2908  auto expectedType =
2909  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2910  if (!expectedType) {
2911  return emitError("failed to infer expectedType from sourceType ")
2912  << sourceType << ", specified resultType is " << resultType;
2913  }
2914  if (resultType.getRank() != expectedType.getRank()) {
2915  return emitError("specified type ")
2916  << resultType << " does not match the inferred type "
2917  << expectedType;
2918  }
2919  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2920  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2921  continue;
2922  if (expectedType.isDynamicDim(i))
2923  continue;
2924  return emitError("specified type ")
2925  << resultType << " does not match the inferred type "
2926  << expectedType;
2927  }
2928 
2929  return success();
2930 }
2931 
2932 LogicalResult PadOp::verifyRegions() {
2933  auto &region = getRegion();
2934  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2935  Block &block = region.front();
2936  if (block.getNumArguments() != rank)
2937  return emitError("expected the block to have ") << rank << " arguments";
2938 
2939  // Note: the number and type of yield values are checked in the YieldOp.
2940  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2941  if (!en.value().isIndex())
2942  return emitOpError("expected block argument ")
2943  << (en.index() + 1) << " to be an index";
2944  }
2945 
2946  // Ensure that the region yields an element of the right type.
2947  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2948  if (yieldOp.getValue().getType() !=
2949  llvm::cast<ShapedType>(getType()).getElementType())
2950  return emitOpError("expected yield type to match shape element type");
2951 
2952  return success();
2953 }
2954 
2955 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2956  ArrayRef<int64_t> staticLow,
2957  ArrayRef<int64_t> staticHigh,
2958  ArrayRef<int64_t> resultShape) {
2959  unsigned rank = sourceType.getRank();
2960  if (staticLow.size() != rank)
2961  return RankedTensorType();
2962  if (staticHigh.size() != rank)
2963  return RankedTensorType();
2964  if (!resultShape.empty() && resultShape.size() != rank)
2965  return RankedTensorType();
2966 
2967  SmallVector<int64_t, 4> inferredShape;
2968  for (auto i : llvm::seq<unsigned>(0, rank)) {
2969  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2970  staticHigh[i] == ShapedType::kDynamic) {
2971  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2972  : resultShape[i]);
2973  } else {
2974  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2975  assert((resultShape.empty() || size == resultShape[i] ||
2976  resultShape[i] == ShapedType::kDynamic) &&
2977  "mismatch between inferred shape and result shape");
2978  inferredShape.push_back(size);
2979  }
2980  }
2981 
2982  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2983 }
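
// For illustration (a hedged example, not part of the original source; the
// concrete shapes are hypothetical): with a source type tensor<4x?xf32>,
// staticLow = [1, 0], and staticHigh = [3, 2], dimension 0 is fully static and
// becomes 4 + 1 + 3 = 8, while dimension 1 stays dynamic, so the inferred
// result type is tensor<8x?xf32>.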
2984 
2985 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2986  Value source, ArrayRef<int64_t> staticLow,
2987  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
2988  bool nofold, ArrayRef<NamedAttribute> attrs) {
2989  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2990  if (!resultType)
2991  resultType = inferResultType(sourceType, staticLow, staticHigh);
2992  result.addAttributes(attrs);
2993  build(b, result, resultType, source, low, high,
2994  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2995  nofold ? b.getUnitAttr() : UnitAttr());
2996 }
2997 
2998 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2999  Value source, ValueRange low, ValueRange high, bool nofold,
3000  ArrayRef<NamedAttribute> attrs) {
3001  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3002  unsigned rank = sourceType.getRank();
3003  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3004  build(b, result, resultType, source, staticVector, staticVector, low, high,
3005  nofold, attrs);
3006 }
3007 
3008 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3009  Value source, ArrayRef<OpFoldResult> low,
3010  ArrayRef<OpFoldResult> high, bool nofold,
3011  ArrayRef<NamedAttribute> attrs) {
3012  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3013  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3014  SmallVector<int64_t, 4> staticLow, staticHigh;
3015  // staticLow and staticHigh carry the full padding configuration: each entry
3016  // of `low`/`high` appends one value to staticLow/staticHigh, and if that
3017  // entry is dynamic (i.e., not a constant), dynamicLow/dynamicHigh grow by
3018  // one value as well.
3019  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3020  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3021  if (!resultType) {
3022  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3023  }
3024  assert(llvm::isa<RankedTensorType>(resultType));
3025  result.addAttributes(attrs);
3026  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3027  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3028  nofold ? b.getUnitAttr() : UnitAttr());
3029 }
3030 
3031 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3032  Value source, ArrayRef<OpFoldResult> low,
3033  ArrayRef<OpFoldResult> high, Value constantPadValue,
3034  bool nofold, ArrayRef<NamedAttribute> attrs) {
3035  build(b, result, resultType, source, low, high, nofold, attrs);
3036 
3037  // Add a region and a block to yield the pad value.
3038  Region *region = result.regions[0].get();
3039  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3040  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3041  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3042 
3043  // `builder.createBlock` changes the insertion point within the block. Create
3044  // a guard to reset the insertion point of the builder after it is destroyed.
3045  OpBuilder::InsertionGuard guard(b);
3046  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3047  b.create<tensor::YieldOp>(result.location, constantPadValue);
3048 }
3049 
3050 llvm::SmallBitVector PadOp::getPaddedDims() {
3051  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3052  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3053  for (const auto &en : enumerate(paddingWidths))
3054  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3055  paddedDims.set(en.index());
3056  };
3057  extractPaddedDims(getMixedLowPad());
3058  extractPaddedDims(getMixedHighPad());
3059  return paddedDims;
3060 }
3061 
3062 namespace {
3063 // Folds tensor.pad when the padding is all static zeros and the nofold
3064 // attribute does not request otherwise.
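//
// Example (an illustrative sketch; the %-names and shapes are hypothetical):
//
// ```mlir
// %padded = tensor.pad %src low[0, 0] high[0, 0] { ...
// } : tensor<?x4xf32> to tensor<8x4xf32>
// ```
//
// folds into:
//
// ```mlir
// %padded = tensor.cast %src : tensor<?x4xf32> to tensor<8x4xf32>
// ```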
3065 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3066  using OpRewritePattern<PadOp>::OpRewritePattern;
3067
3068  LogicalResult matchAndRewrite(PadOp padTensorOp,
3069  PatternRewriter &rewriter) const override {
3070  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3071  return failure();
3072  if (padTensorOp.getNofold())
3073  return failure();
3074  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3075  padTensorOp, padTensorOp.getResult().getType(),
3076  padTensorOp.getSource());
3077  return success();
3078  }
3079 };
3080 
3081 // Fold CastOp into PadOp when adding static information.
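//
// Example (an illustrative sketch; the %-names and shapes are hypothetical):
//
// ```mlir
// %0 = tensor.cast %src : tensor<8x16xf32> to tensor<?x?xf32>
// %padded = tensor.pad %0 low[0, 0] high[1, 1] { ...
// } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// folds into:
//
// ```mlir
// %1 = tensor.pad %src low[0, 0] high[1, 1] { ...
// } : tensor<8x16xf32> to tensor<9x17xf32>
// %padded = tensor.cast %1 : tensor<9x17xf32> to tensor<?x?xf32>
// ```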
3082 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3083  using OpRewritePattern<PadOp>::OpRewritePattern;
3084
3085  LogicalResult matchAndRewrite(PadOp padTensorOp,
3086  PatternRewriter &rewriter) const override {
3087  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3088  if (!tensor::canFoldIntoConsumerOp(castOp))
3089  return failure();
3090 
3091  auto newResultType = PadOp::inferResultType(
3092  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3093  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3094  padTensorOp.getResultType().getShape());
3095 
3096  if (newResultType == padTensorOp.getResultType()) {
3097  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3098  padTensorOp.getSourceMutable().assign(castOp.getSource());
3099  });
3100  } else {
3101  auto newOp = rewriter.create<PadOp>(
3102  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3103  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3104  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3105  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3106  IRMapping mapper;
3107  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3108 
3109  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3110  padTensorOp, padTensorOp.getResultType(), newOp);
3111  }
3112  return success();
3113  }
3114 };
3115 
3116 // Fold CastOp using the result of PadOp back into the latter if it adds
3117 // static information.
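//
// Example (an illustrative sketch; the %-names and shapes are hypothetical):
//
// ```mlir
// %0 = tensor.pad %src low[0, 0] high[%h0, %h1] { ...
// } : tensor<8x16xf32> to tensor<?x?xf32>
// %res = tensor.cast %0 : tensor<?x?xf32> to tensor<10x20xf32>
// ```
//
// folds into:
//
// ```mlir
// %res = tensor.pad %src low[0, 0] high[%h0, %h1] { ...
// } : tensor<8x16xf32> to tensor<10x20xf32>
// ```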
3118 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3119  using OpRewritePattern<PadOp>::OpRewritePattern;
3120
3121  LogicalResult matchAndRewrite(PadOp padTensorOp,
3122  PatternRewriter &rewriter) const override {
3123  if (!padTensorOp.getResult().hasOneUse())
3124  return failure();
3125  auto tensorCastOp =
3126  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3127  if (!tensorCastOp)
3128  return failure();
3129  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3130  tensorCastOp.getDest().getType()))
3131  return failure();
3132 
3133  auto replacementOp = rewriter.create<PadOp>(
3134  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3135  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3136  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3137  padTensorOp.getHigh(), padTensorOp.getNofold(),
3138  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3139  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3140 
3141  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3142  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3143  return success();
3144  }
3145 };
3146 
3147 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3148 /// different dimensions. The pattern applies if the following preconditions
3149 /// hold:
3150 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3151 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3152 /// 3) the tensor::PadOps perform only high-padding,
3153 /// 4) the tensor::PadOps have the same constant padding value,
3154 /// 5) the tensor::PadOps do not have common padding dimensions,
3155 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3156 /// zero-offset for every dimension.
3157 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
3158 ///    padded source dimensions.
3160 ///
3161 /// Example:
3162 ///
3163 /// ```mlir
3164 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3165 /// : tensor<64x64xf32> to tensor<?x64xf32>
3166 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3167 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3168 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3169 /// : tensor<8x64xf32> to tensor<8x?xf32>
3170 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3171 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3172 /// ```
3173 ///
3174 /// folds into:
3175 ///
3176 /// ```mlir
3177 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3178 /// : tensor<64x64xf32> to tensor<?x?xf32>
3179 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3180 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3181 /// ```
3182 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3183  using OpRewritePattern<PadOp>::OpRewritePattern;
3184
3185  LogicalResult matchAndRewrite(PadOp padOp,
3186  PatternRewriter &rewriter) const override {
3187  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3188  if (!innerSliceOp)
3189  return failure();
3190  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3191  if (!outerPadOp || outerPadOp.getNofold())
3192  return failure();
3193  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3194  if (!outerSliceOp)
3195  return failure();
3196 
3197  // 1) Fail if the chain is rank-reducing.
3198  int64_t rank = padOp.getSourceType().getRank();
3199  if (outerSliceOp.getSourceType().getRank() != rank) {
3200  return rewriter.notifyMatchFailure(padOp,
3201  "cannot fold rank-reducing chain");
3202  }
3203 
3204  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3205  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3206  return rewriter.notifyMatchFailure(
3207  padOp, "cannot fold non-unit stride ExtractSliceOps");
3208  }
3209 
3210  // 3) Fail if the tensor::PadOps have non-zero low padding.
3211  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3212  return rewriter.notifyMatchFailure(padOp,
3213  "cannot fold PadOps with low padding");
3214  }
3215 
3216  // 4) Fail if the tensor::PadOps padding values do not match.
3217  Attribute innerAttr, outerAttr;
3218  Value innerValue = padOp.getConstantPaddingValue();
3219  Value outerValue = outerPadOp.getConstantPaddingValue();
3220  if (!innerValue || !outerValue ||
3221  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3222  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3223  innerAttr != outerAttr) {
3224  return rewriter.notifyMatchFailure(
3225  padOp, "cannot fold PadOps with different padding values");
3226  }
3227 
3228  // 5) Fail if a dimension is padded by both tensor::PadOps.
3229  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3230  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3231  if (innerDims.anyCommon(outerDims)) {
3232  return rewriter.notifyMatchFailure(
3233  padOp, "cannot fold PadOps with common padding dimensions");
3234  }
3235 
3236  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3237  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3238  // for every dimension, and use the offset of the other pair. Fail if no
3239  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3240  // exists.
3241  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3242  for (auto en : enumerate(newOffsets)) {
3243  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3244  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3245  if (!innerDims.test(en.index()) &&
3246  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3247  en.value() = outerOffset;
3248  continue;
3249  }
3250  if (!outerDims.test(en.index()) &&
3251  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3252  en.value() = innerOffset;
3253  continue;
3254  }
3255  return rewriter.notifyMatchFailure(
3256  padOp, "cannot find zero-offset and zero-padding pair");
3257  }
3258 
3259  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3260  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3261  // outer tensor::PadOp and fail if the size of the inner
3262  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3263  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3264  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3265  for (auto en : enumerate(newSizes)) {
3266  if (!outerDims.test(en.index()))
3267  continue;
3268  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3269  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3270  assert(!ShapedType::isDynamic(sourceSize) &&
3271  "expected padded dimension to have a static size");
3272  if (getConstantIntValue(sliceSize) != sourceSize) {
3273  return rewriter.notifyMatchFailure(
3274  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3275  "match the size of the outer padding");
3276  }
3277  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3278  }
3279 
3280  // Combine the high paddings of the two tensor::PadOps.
3281  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3282  for (auto en : enumerate(newHighPad)) {
3283  if (innerDims.test(en.index()))
3284  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3285  if (outerDims.test(en.index()))
3286  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3287  }
3288 
3289  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3290  // the two paddings in one step.
3291  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3292  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3293  innerSliceOp.getMixedStrides());
3294  auto newPadOp = rewriter.create<PadOp>(
3295  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3296  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3297  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3298  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3299  newPadOp.getRegion().begin());
3300  rewriter.replaceOp(padOp, newPadOp.getResult());
3301  return success();
3302  }
3303 };
3304 
3305 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3306  using OpRewritePattern<PadOp>::OpRewritePattern;
3307
3308  LogicalResult matchAndRewrite(PadOp padTensorOp,
3309  PatternRewriter &rewriter) const override {
3310  Value input = padTensorOp.getSource();
3311  if (!llvm::isa<RankedTensorType>(input.getType()))
3312  return failure();
3313  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3314  auto inputRank = inputDims.size();
3315 
3316  auto oldResultType =
3317  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3318  if (!oldResultType)
3319  return failure();
3320 
3321  auto outputDims = oldResultType.getShape();
3322 
3323  // Extract the static info from the high and low operands.
3324  SmallVector<int64_t> constOperandsLow;
3325  SmallVector<Value> newLows;
3326  for (auto operand : padTensorOp.getLow()) {
3327  APSInt intOp;
3328  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3329  constOperandsLow.push_back(ShapedType::kDynamic);
3330  newLows.push_back(operand);
3331  continue;
3332  }
3333  constOperandsLow.push_back(intOp.getExtValue());
3334  }
3335  SmallVector<int64_t> constOperandsHigh;
3336  SmallVector<Value> newHighs;
3337  for (auto operand : padTensorOp.getHigh()) {
3338  APSInt intOp;
3339  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3340  constOperandsHigh.push_back(ShapedType::kDynamic);
3341  newHighs.push_back(operand);
3342  continue;
3343  }
3344  constOperandsHigh.push_back(intOp.getExtValue());
3345  }
3346 
3347  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3348  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3349 
3350  // Verify the op is well-formed.
3351  if (inputDims.size() != outputDims.size() ||
3352  inputDims.size() != constLow.size() ||
3353  inputDims.size() != constHigh.size())
3354  return failure();
3355 
3356  auto lowCount = 0;
3357  auto highCount = 0;
3358  for (size_t i = 0; i < inputRank; i++) {
3359  if (constLow[i] == ShapedType::kDynamic)
3360  constLow[i] = constOperandsLow[lowCount++];
3361  if (constHigh[i] == ShapedType::kDynamic)
3362  constHigh[i] = constOperandsHigh[highCount++];
3363  }
3364 
3365  auto staticLow = ArrayRef<int64_t>(constLow);
3366  auto staticHigh = ArrayRef<int64_t>(constHigh);
3367 
3368  // Calculate the output sizes with the static information.
3369  SmallVector<int64_t> newOutDims;
3370  for (size_t i = 0; i < inputRank; i++) {
3371  if (outputDims[i] == ShapedType::kDynamic) {
3372  newOutDims.push_back(
3373  (staticLow[i] == ShapedType::kDynamic ||
3374  staticHigh[i] == ShapedType::kDynamic ||
3375  inputDims[i] == ShapedType::kDynamic
3376  ? ShapedType::kDynamic
3377  : inputDims[i] + staticLow[i] + staticHigh[i]));
3378  } else {
3379  newOutDims.push_back(outputDims[i]);
3380  }
3381  }
3382 
3383  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3384  llvm::all_of(newOutDims,
3385  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3386  return failure();
3387 
3388  // Rewrite the op using the new static type.
3389  auto newResultType = RankedTensorType::get(
3390  newOutDims, padTensorOp.getType().getElementType());
3391  auto newOp = rewriter.create<PadOp>(
3392  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3393  newLows, newHighs, padTensorOp.getNofold(),
3394  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3395 
3396  IRMapping mapper;
3397  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3398  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3399  newOp);
3400 
3401  return success();
3402  }
3403 };
3404 
3405 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3406 ///
3407 /// Example:
3408 ///
3409 /// ```mlir
3410 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3411 /// tensor.yield %val
3412 /// } : tensor<1x2xf32> to tensor<1x5xf32>
3413 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3414 /// tensor.yield %val
3415 /// } : tensor<1x5xf32> to tensor<4x7xf32>
3416 /// ```
3417 ///
3418 /// folds into:
3419 ///
3420 /// ```mlir
3421 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3422 /// tensor.yield %val
3423 /// } : tensor<1x2xf32> to tensor<4x7xf32>
3424 /// ```
3425 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3426  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3427
3428  LogicalResult matchAndRewrite(tensor::PadOp padOp,
3429  PatternRewriter &rewriter) const override {
3430  if (padOp.getNofold()) {
3431  return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3432  }
3433 
3434  auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3435  if (!producerPad || producerPad.getNofold()) {
3436  return rewriter.notifyMatchFailure(
3437  padOp, "producer is not a foldable tensor.pad op");
3438  }
3439 
3440  // Fail if the tensor::PadOps padding values do not match.
3441  Value consumerPadValue = padOp.getConstantPaddingValue();
3442  Value producerPadValue = producerPad.getConstantPaddingValue();
3443  if (!consumerPadValue || !producerPadValue ||
3444  consumerPadValue != producerPadValue) {
3445  return rewriter.notifyMatchFailure(
3446  padOp,
3447  "cannot fold PadOps with different or non-constant padding values");
3448  }
3449 
3450  Location loc = padOp.getLoc();
3451  AffineExpr d0, d1;
3452  bindDims(rewriter.getContext(), d0, d1);
3453 
3454  // Combine the low/high paddings of the two tensor::PadOps.
3455  auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3456  ArrayRef<OpFoldResult> producerPaddings) {
3457  SmallVector<OpFoldResult> sumPaddings;
3458  for (auto [consumerIndex, producerIndex] :
3459  llvm::zip_equal(consumerPaddings, producerPaddings)) {
3460  sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3461  rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3462  }
3463  return sumPaddings;
3464  };
3465 
3466  SmallVector<OpFoldResult> newHighPad =
3467  addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3468  SmallVector<OpFoldResult> newLowPad =
3469  addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3470 
3471  auto newPadOp = rewriter.create<tensor::PadOp>(
3472  padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3473  newLowPad, newHighPad, padOp.getNofold(),
3474  getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3475  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3476  newPadOp.getRegion().begin());
3477  rewriter.replaceOp(padOp, newPadOp.getResult());
3478  return success();
3479  }
3480 };
3481 
3482 } // namespace
3483 
3484 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3485  MLIRContext *context) {
3486  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3487  FoldOrthogonalPaddings, FoldStaticPadding,
3488  FoldConsecutiveConstantPadding>(context);
3489 }
3490 
3491 /// Return the padding value of the PadOp if it is constant. In this context,
3492 /// "constant" means an actual constant or "defined outside of the block".
3493 ///
3494 /// Values are considered constant in three cases:
3495 /// - A ConstantLike value.
3496 /// - A basic block argument from a different block.
3497 /// - A value defined outside of the block.
3498 ///
3499 /// If the padding value is not constant, an empty Value is returned.
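///
/// Example (illustrative only; the names and shapes are hypothetical):
///
/// ```mlir
/// func.func @pad(%src: tensor<4xf32>, %val: f32) -> tensor<6xf32> {
///   %padded = tensor.pad %src low[1] high[1] {
///   ^bb0(%i: index):
///     tensor.yield %val : f32
///   } : tensor<4xf32> to tensor<6xf32>
///   return %padded : tensor<6xf32>
/// }
/// ```
///
/// Here the yielded %val is a function argument, i.e. defined outside of the
/// pad body, so it is returned even though it is not a literal constant.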
3500 Value PadOp::getConstantPaddingValue() {
3501  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3502  if (!yieldOp)
3503  return {};
3504  Value padValue = yieldOp.getValue();
3505  // Check if yield value is a constant.
3506  if (matchPattern(padValue, m_Constant()))
3507  return padValue;
3508  // Check if yield value is defined inside the PadOp block.
3509  if (padValue.getParentBlock() == &getRegion().front())
3510  return {};
3511  // Else: Yield value defined outside of the PadOp block.
3512  return padValue;
3513 }
3514 
3515 OpFoldResult PadOp::fold(FoldAdaptor) {
3516  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3517  !getNofold())
3518  return getSource();
3519  return {};
3520 }
3521 
3522 //===----------------------------------------------------------------------===//
3523 // ParallelInsertSliceOp
3524 //===----------------------------------------------------------------------===//
3525 
3526 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3527  ParallelCombiningOpInterface parallelCombiningParent =
3528  getParallelCombiningParent();
3529  for (const auto &it :
3530  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3531  Operation &nextOp = it.value();
3532  if (&nextOp == getOperation())
3533  return parallelCombiningParent.getParentResult(it.index());
3534  }
3535  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3536 }
3537 
3538 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3539 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3540  Value source, Value dest,
3541  ArrayRef<OpFoldResult> offsets,
3542  ArrayRef<OpFoldResult> sizes,
3543  ArrayRef<OpFoldResult> strides,
3544  ArrayRef<NamedAttribute> attrs) {
3545  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3546  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3547  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3548  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3549  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3550  result.addAttributes(attrs);
3551  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3552  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3553  b.getDenseI64ArrayAttr(staticSizes),
3554  b.getDenseI64ArrayAttr(staticStrides));
3555 }
3556 
3557 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3558 /// packed into a Range vector.
3559 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3560  Value source, Value dest,
3561  ArrayRef<Range> ranges,
3562  ArrayRef<NamedAttribute> attrs) {
3563  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3564  build(b, result, source, dest, offsets, sizes, strides, attrs);
3565 }
3566 
3567 // Build a ParallelInsertSliceOp with dynamic entries.
3568 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3569  Value source, Value dest, ValueRange offsets,
3570  ValueRange sizes, ValueRange strides,
3571  ArrayRef<NamedAttribute> attrs) {
3572  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3573  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3574  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3575  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3576  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3577  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3578  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3579 }
3580 
3581 LogicalResult ParallelInsertSliceOp::verify() {
3582  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3583  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3584  << *(getOperation()->getParentOp());
3585 
3586  RankedTensorType expectedType;
3587  SliceVerificationResult result =
3588  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3589  getStaticSizes(), getStaticStrides(), &expectedType);
3590  return produceSliceErrorMsg(result, *this, expectedType);
3591 }
3592 
3593 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3594  RewritePatternSet &results, MLIRContext *context) {
3595  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3596  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3597  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3598 }
3599 
3600 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3601  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3602 }
3603 
3604 //===----------------------------------------------------------------------===//
3605 // ScatterOp
3606 //===----------------------------------------------------------------------===//
3607 
3608 void ScatterOp::getAsmResultNames(
3609  function_ref<void(Value, StringRef)> setNameFn) {
3610  setNameFn(getResult(), "scatter");
3611 }
3612 
3613 LogicalResult ScatterOp::verify() {
3614  int64_t destRank = getDestType().getRank();
3615  ArrayRef<int64_t> scatterDims = getScatterDims();
3616  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3617  getIndicesType().getShape(), destRank,
3618  "scatter", "dest")))
3619  return failure();
3620 
3621  if (!getUnique())
3622  return emitOpError("requires 'unique' attribute to be set");
3623  // TODO: we could also check statically that there are fewer leading index
3624  // tensor dims than the dest dims. If this is not the case, the unique
3625  // attribute cannot be true.
3626 
3627  // Use the GatherOp::inferResultType on the `dest` type and verify the
3628  // expected type matches the source type.
3629  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3630  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3631  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3632  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3633  if (getSourceType() != expectedSourceType &&
3634  getSourceType() != expectedRankReducedSourceType) {
3635  return emitOpError("source type "
3636  "mismatch: "
3637  "expected ")
3638  << expectedSourceType << " or its rank-reduced variant "
3639  << expectedRankReducedSourceType << " (got: " << getSourceType()
3640  << ")";
3641  }
3642 
3643  return success();
3644 }
3645 
3646 //===----------------------------------------------------------------------===//
3647 // SplatOp
3648 //===----------------------------------------------------------------------===//
3649 
3650 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3651  Type aggregateType, ValueRange dynamicSizes) {
3652  build(builder, result, aggregateType, element, dynamicSizes);
3653 }
3654 
3655 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3656  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3657  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3658  build(builder, result, aggregateType, element, dynamicSizes);
3659 }
3660 
3661 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3662  ArrayRef<OpFoldResult> sizes) {
3663  SmallVector<int64_t> staticShape;
3664  SmallVector<Value> dynamicSizes;
3665  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3666  build(builder, result, element, staticShape, dynamicSizes);
3667 }
3668 
3669 void SplatOp::getAsmResultNames(
3670  function_ref<void(Value, StringRef)> setNameFn) {
3671  setNameFn(getResult(), "splat");
3672 }
3673 
3674 LogicalResult SplatOp::verify() {
3675  if (getType().getNumDynamicDims() !=
3676  static_cast<int64_t>(getDynamicSizes().size()))
3677  return emitOpError("incorrect number of dynamic sizes, has ")
3678  << getDynamicSizes().size() << ", expected "
3679  << getType().getNumDynamicDims();
3680  return success();
3681 }
3682 
3683 LogicalResult
3684 SplatOp::reifyResultShapes(OpBuilder &builder,
3685  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3686  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3687  unsigned ctr = 0;
3688  for (int64_t i = 0; i < getType().getRank(); ++i) {
3689  if (getType().isDynamicDim(i)) {
3690  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3691  } else {
3692  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3693  }
3694  }
3695  return success();
3696 }
3697 
3698 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3699  auto constOperand = adaptor.getInput();
3700  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3701  return {};
3702 
3703  // Do not fold if the splat is not statically shaped
3704  if (!getType().hasStaticShape())
3705  return {};
3706 
3707  // SplatElementsAttr::get treats a single value for the second arg as a
3708  // splat.
3709  return SplatElementsAttr::get(getType(), {constOperand});
3710 }
3711 
3712 //===----------------------------------------------------------------------===//
3713 // PackOp/UnPackOp Common
3714 //===----------------------------------------------------------------------===//
3715 
3716 template <typename OpTy>
3717 static LogicalResult
3718 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3719  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3720  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3721  "applies to only pack or unpack operations");
3722  int64_t destRank = op.getDestRank();
3723  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3724  reifiedReturnShapes[0] =
3725  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3726  return success();
3727 }
3728 
3729 template <typename OpTy>
3730 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3731  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3732  "applies to only pack or unpack operations");
3733  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3734  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3735  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3736  assert(tiles.size() == dimsToTile.size() &&
3737  "tiles must match indices of dimension to block");
3738  // Bind dimension `i` to its tile factor.
3739  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3740  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3741  return dimAndTileMapping;
3742 }
3743 
3744 template <typename OpTy>
3745 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3746  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3747  "applies to only pack or unpack operations");
3748  Builder builder(op);
3749  SmallVector<OpFoldResult> mixedInnerTiles;
3750  unsigned dynamicValIndex = 0;
3751  for (int64_t staticTile : op.getStaticInnerTiles()) {
3752  if (!ShapedType::isDynamic(staticTile))
3753  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3754  else
3755  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3756  }
3757  return mixedInnerTiles;
3758 }
3759 
3760 template <typename OpTy>
3761 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3762  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3763  "applies to only pack or unpack operations");
3764  SmallVector<Value> dynamicTiles;
3765  SmallVector<int64_t> staticTiles;
3766  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3767  return staticTiles;
3768 }
3769 
3770 /// Returns true if `dimsPos` is invalid. It is invalid when:
3771 /// a) it contains duplicate entries,
3772 /// b) at least one position is out of bounds (a valid `dimPos` satisfies
3773 ///    0 <= dimPos < rank), or c) `dimsPos` has more elements than `rank`.
3774 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3775  size_t rank) {
3776  size_t dimsPosSize = dimsPos.size();
3777  if (dimsPosSize > rank)
3778  return true;
3779  DenseSet<int64_t> uniqued;
3780  for (int64_t dim : dimsPos)
3781  uniqued.insert(dim);
3782  if (dimsPosSize != uniqued.size())
3783  return true;
3784  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3785  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3786  });
3787 }
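// For illustration (hypothetical values, not from the original source): with
// rank = 3, dimsPos = [0, 2] is valid, while [0, 0] (duplicate entries),
// [1, 3] (out of bounds), and [0, 1, 2, 3] (too many entries) are all
// reported as invalid.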
3788 
3789 /// Returns true if each dimension of `sourceShape` is smaller than or equal
3790 /// to the corresponding dimension of `limitShape`; dynamic dims are in bound.
3791 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3792  ArrayRef<int64_t> limitShape) {
3793  assert(
3794  sourceShape.size() == limitShape.size() &&
3795  "expected source shape rank, and limit of the shape to have same rank");
3796  return llvm::all_of(
3797  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3798  int64_t sourceExtent = std::get<0>(it);
3799  int64_t limit = std::get<1>(it);
3800  return ShapedType::isDynamic(sourceExtent) ||
3801  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3802  });
3803 }
3804 
3805 template <typename OpTy>
3806 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3807  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3808  "applies to only pack or unpack operations");
3809  Operation *op = packOrUnPack.getOperation();
3810 
3811  // Return true if we have a zero-value tile.
3812  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3813  return llvm::any_of(
3814  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3815  };
3816 
3817  // Verify tiles. Do not allow zero tiles.
3818  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3819  if (hasZeros(mixedTiles))
3820  return op->emitError("invalid zero tile factor");
3821 
3822  // Verify inner_dims_pos and outer_dims_perm.
3823  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3824  ? packOrUnPack.getSourceType()
3825  : packOrUnPack.getDestType();
3826  size_t unpackedRank = unpackedType.getRank();
3827  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3828  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3829  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3830  return op->emitError("invalid inner_dims_pos vector");
3831  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3832  return op->emitError("invalid outer_dims_perm vector");
3833  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3834  return op->emitError("outer_dims_perm must be a permutation or empty");
3835 
3836  // Tiling factors must be less than or equal to the input rank for pack (or
3837  // output rank for unpack), and must match the number of `inner_dims_pos`.
3838  if (mixedTiles.size() > unpackedRank) {
3839  return op->emitError("tiling factors must be less than or equal to the "
3840  "input rank for pack or output rank for unpack");
3841  }
3842  if (mixedTiles.size() != innerDimsPos.size()) {
3843  return op->emitError(
3844  "tiling factors must equal the number of dimensions to tile");
3845  }
3846 
3847  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3848  ? packOrUnPack.getDestType()
3849  : packOrUnPack.getSourceType();
3850  size_t packedRank = packedType.getRank();
3851  // Require output rank to match input rank + number of blocking factors.
3852  if (unpackedRank + mixedTiles.size() != packedRank) {
3853  return op->emitError(
3854  "packed rank must equal unpacked rank + tiling factors");
3855  }
3856 
3857  // Verify that the result shape is at least as large as the minimum shape
3858  // expected by the pack operation, and that the output shape represents
3859  // full tiles.
3860  RankedTensorType expectedPackedType = PackOp::inferPackedType(
3861  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3862  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3863  return op->emitError("the shape of output is not large enough to hold the "
3864  "packed data. Expected at least ")
3865  << expectedPackedType << ", got " << packedType;
3866  }
3867  if (!llvm::all_of(
3868  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3869  mixedTiles),
3870  [](std::tuple<int64_t, OpFoldResult> it) {
3871  std::optional<int64_t> constTileSize =
3872  getConstantIntValue(std::get<1>(it));
3873  int64_t shape = std::get<0>(it);
3874  if (!constTileSize) {
3875  // If specified tile size is dynamic, output shape should
3876  // be dynamic too.
3877  return ShapedType::isDynamic(shape);
3878  }
3879  if (ShapedType::isDynamic(shape)) {
3880  // For the shape being dynamic when tile size is
3881  // specified, return true. In canonical form a constant
3882  // tile size should lead to constant shape of the tiled
3883  // dimension, but not needed for verification.
3884  return true;
3885  }
3886  return shape == constTileSize.value();
3887  })) {
3888  return op->emitError("mismatch in inner tile sizes specified and shaped of "
3889  "tiled dimension in the packed type");
3890  }
3891  return success();
3892 }
3893 
3894 namespace {
3895 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
3896 /// various permutations to the op.
3897 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
3898 // these. These may or may not become true foldings / canonicalizations
3899 // depending on how aggressive we want to be in automatically folding
3900 // transposes.
3901 struct PackOrUnPackTransposeResult {
3902  SmallVector<int64_t> innerDimsPos;
3903  SmallVector<OpFoldResult> innerTiles;
3904  SmallVector<int64_t> outerDimsPerm;
3905 };
3906 } // namespace
3907 
3908 template <typename OpTy>
3909 static PackOrUnPackTransposeResult
3910 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
3911  ArrayRef<int64_t> innerPermutation,
3912  ArrayRef<int64_t> outerPermutation) {
3913  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3914  "applies to only pack or unpack operations");
3915  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3916  "some permutation must be non-empty");
3917  PackOrUnPackTransposeResult metadata;
3918  metadata.innerDimsPos =
3919  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
3920  metadata.innerTiles =
3921  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
3922  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3923  ? packOrUnPackOp.getSourceRank()
3924  : packOrUnPackOp.getDestRank();
3925  metadata.outerDimsPerm =
3926  packOrUnPackOp.getOuterDimsPerm().empty()
3927  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3928  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
3929  if (!innerPermutation.empty()) {
3930  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3931  isPermutationVector(innerPermutation) &&
3932  "invalid inner permutation");
3933  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
3934  applyPermutationToVector(metadata.innerTiles, innerPermutation);
3935  }
3936  if (!outerPermutation.empty()) {
3937  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3938  isPermutationVector(outerPermutation) &&
3939  "invalid outer permutation");
3940  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
3941  }
3942  return metadata;
3943 }
3944 
3945 //===----------------------------------------------------------------------===//
3946 // PackOp
3947 //===----------------------------------------------------------------------===//
3948 
3949 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3950  setNameFn(getResult(), "pack");
3951 }
3952 
3953 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3954  Value dest, ArrayRef<int64_t> innerDimsPos,
3955  ArrayRef<OpFoldResult> innerTiles,
3956  std::optional<Value> paddingValue,
3957  ArrayRef<int64_t> outerDimsPerm) {
3958  assert(innerDimsPos.size() == innerTiles.size() &&
3959  "number of tile sizes specified must match the specified number of "
3960  "original dimensions to be tiled");
3961  SmallVector<int64_t> staticTileSizes;
3962  SmallVector<Value> dynamicTileSizes;
3963  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3964  build(builder, state, dest.getType(), source, dest,
3965  paddingValue ? *paddingValue : nullptr,
3966  outerDimsPerm.empty() ? nullptr
3967  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3968  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3969  builder.getDenseI64ArrayAttr(staticTileSizes));
3970 }
3971 
3972 LogicalResult
3973 PackOp::reifyResultShapes(OpBuilder &builder,
3974  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3975  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3976 }
3977 
3978 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3979  return getDimAndTileMappingImpl(*this);
3980 }
3981 
3982 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3983  return getMixedTilesImpl(*this);
3984 }
3985 
3986 SmallVector<int64_t> PackOp::getStaticTiles() {
3987  return getStaticTilesImpl(*this);
3988 }
3989 
3990 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
3991  ArrayRef<int64_t> innerDimsPos,
3992  ArrayRef<int64_t> outputShape,
3993  ArrayRef<int64_t> outerDimsPerm,
3994  ArrayRef<OpFoldResult> innerTiles) {
3995  SmallVector<int64_t> outputTileSizes(
3996  outputShape.take_front(inputShape.size()));
3997  if (!outerDimsPerm.empty()) {
3998  assert(outerDimsPerm.size() == outputTileSizes.size() &&
3999  "expected output and outer_dims_perm to have same size");
4000  applyPermutationToVector(outputTileSizes,
4001  invertPermutationVector(outerDimsPerm));
4002  }
4003  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4004  if (ShapedType::isDynamic(inputShape[pos]))
4005  continue;
4006  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4007 
4008  if (!constantTile) {
4009  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4010  (inputShape[pos] % outputTileSizes[pos] != 0))
4011  return true;
4012  } else if (inputShape[pos] % (*constantTile) != 0) {
4013  return true;
4014  }
4015  }
4016  return false;
4017 }
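// For illustration (hypothetical values): packing a static dimension of size
// 10 with an inner tile of 4 leaves a partial tile (10 % 4 != 0), so a padding
// value is required; a dimension of size 12 with tile 4 divides evenly and
// needs no padding value.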
4018 
4019 LogicalResult PackOp::verify() {
4020  if (failed(commonVerifierPackAndUnPackOp(*this)))
4021  return failure();
4022 
4023  // Verify padding value, and bail out if the tile does not divide the
4024  // dimension fully. In the case of dynamic tile factors or dimensions, having
4025  // a partial tile is undefined behavior.
4026  auto paddingValue = getPaddingValue();
4027  if (paddingValue &&
4028  paddingValue.getType() != getSourceType().getElementType()) {
4029  return emitOpError("expected padding_value has ")
4030  << getSourceType().getElementType()
4031  << " but got: " << paddingValue.getType();
4032  }
4033 
4034  if (!paddingValue &&
4035  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4036  getDestType().getShape(), getOuterDimsPerm(),
4037  getMixedTiles())) {
4038  return emitOpError(
4039  "invalid tile factor or output size provided. Only full tiles are "
4040  "supported when padding_value is not set");
4041  }
4042  return success();
4043 }
4044 
4045 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
4046 /// all Values to kDynamic, even if they are arith.constant values.
4047 static SmallVector<int64_t>
4048 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4049  SmallVector<int64_t> result;
4050  for (auto o : ofrs) {
4051  // Have to do this first, as getConstantIntValue special-cases constants.
4052  if (llvm::dyn_cast_if_present<Value>(o))
4053  result.push_back(ShapedType::kDynamic);
4054  else
4055  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4056  }
4057  return result;
4058 }
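// For illustration (hypothetical values): {IntegerAttr(4), %v, IntegerAttr(8)}
// maps to [4, kDynamic, 8], with %v treated as dynamic even if it happens to
// be produced by an arith.constant.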
4059 
4060 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4061 /// the packed type. Having a shared helper helps implement these two methods in
4062 /// a way that ensures that they agree on which dimensions are dynamic.
4063 static SmallVector<int64_t> getPackOpResultTypeShape(
4064  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4065  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4066  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4067  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4068  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4069  continue;
4070  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4071  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4072  continue;
4073  }
4074  resultShape[tiledDim.value()] = divideCeilSigned(
4075  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4076  }
4077 
4078  // Swap tile loops if outer_dims_perm is available.
4079  if (!outerDimsPerm.empty())
4080  applyPermutationToVector(resultShape, outerDimsPerm);
4081 
4082  // Append the inner tile dimensions.
4083  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4084  return resultShape;
4085 }
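// Worked example (illustrative, hypothetical values): sourceShape = [16, 8],
// innerTileSizes = [4], innerDimsPos = [1], outerDimsPerm = [1, 0]. Tiling
// dim 1 gives ceil(8 / 4) = 2, i.e. [16, 2]; the outer permutation swaps this
// to [2, 16]; appending the inner tile yields the packed shape [2, 16, 4].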
4086 
4087 SmallVector<OpFoldResult> PackOp::getResultShape(
4088  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4089  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4090  ArrayRef<int64_t> outerDimsPerm) {
4091  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4092 
4093  AffineExpr s0, s1;
4094  bindSymbols(builder.getContext(), s0, s1);
4095  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4096  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4097  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4098  builder, loc, ceilDivExpr,
4099  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4100  }
4101  if (!outerDimsPerm.empty())
4102  applyPermutationToVector(resultDims, outerDimsPerm);
4103  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4104 
4105  SmallVector<int64_t> resultTypeShape =
4106  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4107  asShapeWithAnyValueAsDynamic(innerTileSizes),
4108  innerDimsPos, outerDimsPerm);
4109 
4110  // Fix up `resultDims` to ensure that they are Values if and only if the
4111  // result type shape says it's a dynamic dim. This is needed as callers may
4112  // use dispatchIndexOpFoldResults on the result, and rely on the exact number
4113  // of dynamic dims returned by that.
4114  for (unsigned i = 0; i < resultDims.size(); ++i) {
4115  if (!ShapedType::isDynamic(resultTypeShape[i]))
4116  continue;
4117  resultDims[i] =
4118  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4119  }
4120 
4121  return resultDims;
4122 }
4123 
4124 /// Get the expected packed type based on source type, tile factors, position of
4125 /// the inner tiles and permutation of the outer tiled loop.
4126 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4127  ArrayRef<int64_t> innerTileSizes,
4128  ArrayRef<int64_t> innerDimsPos,
4129  ArrayRef<int64_t> outerDimsPerm) {
4130  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4131  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4132  return RankedTensorType::get(resultShape, sourceType.getElementType());
4133 }
4134 
4135 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4136  ArrayRef<OpFoldResult> innerTileSizes,
4137  ArrayRef<int64_t> innerDimsPos,
4138  ArrayRef<int64_t> outerDimsPerm) {
4139  AffineExpr dim0, dim1;
4140  bindDims(b.getContext(), dim0, dim1);
4141  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4142  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4143  {v1, v2});
4144  };
4145 
4146  SmallVector<OpFoldResult> mixedSizes;
4147  for (auto [index, value] : llvm::enumerate(
4148  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4149  if (ShapedType::isDynamic(value))
4150  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
4151  else
4152  mixedSizes.push_back(b.getIndexAttr(value));
4153  }
4154  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4155  int64_t dimPos = std::get<0>(it);
4156  OpFoldResult tileSize = std::get<1>(it);
4157  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4158  }
4159  if (!outerDimsPerm.empty())
4160  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4161 
4162  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4163  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4164  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4165 }
4166 
4167 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4168  ArrayRef<int64_t> innerPermutation,
4169  ArrayRef<int64_t> outerPermutation) {
4170  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4171  *this, innerPermutation, outerPermutation);
4172  Value transposedDest =
4173  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4174  metadata.innerDimsPos, metadata.outerDimsPerm);
4175  return b.create<PackOp>(loc, getSource(), transposedDest,
4176  metadata.innerDimsPos, metadata.innerTiles,
4177  getPaddingValue(), metadata.outerDimsPerm);
4178 }
4179 
4180 /// Returns true if the tiles and the tiled dims are constant.
4181 template <typename OpTy>
4182 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
4183  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4184  "applies to only pack or unpack operations");
4185  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4186  ? op.getDestType()
4187  : op.getSourceType();
4188  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4189  for (auto [dimDest, tile] : llvm::zip(
4190  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4191  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4192  if (!constTileSize || ShapedType::isDynamic(dimDest))
4193  return false;
4194  }
4195  return true;
4196 }
4197 
4198 Speculation::Speculatability PackOp::getSpeculatability() {
4199  if (getPaddingValue())
4200  return Speculation::Speculatable;
4201
4202  // The verifier already rejects operations where we can statically prove that
4203  // the tile sizes do not divide the dimension evenly; thus, it is enough to
4204  // check that the tiles and the tiled inner dimensions are constant.
4205  if (!areTilesAndTiledDimsAllConstant(*this))
4206  return Speculation::NotSpeculatable;
4207
4208  return Speculation::Speculatable;
4209 }
4210 
4211 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4212 // dimensions for pack and unpack.
4213 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4214  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4215  return false;
4216  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4217  return true;
4218  // Outer dims permutation is optional.
4219  // To compare an unbalanced pack-unpack pair, treat a missing permutation as
4220  // the identity permutation.
4221  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4222  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4223 }
4224 
4225 // Return true if pack and unpack have the same tiles.
4226 // Same SSA values or same integer constants.
4227 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4228  auto packTiles = packOp.getMixedTiles();
4229  auto unPackTiles = unPackOp.getMixedTiles();
4230  if (packTiles.size() != unPackTiles.size())
4231  return false;
4232  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4233  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4234  return false;
4235  }
4236  return true;
4237 }
4238 
4239 /// Returns true if the pack op does not need a padding value.
4240 static bool paddingIsNotNeeded(PackOp op) {
4241  auto srcType = op.getSourceType();
4242  if (llvm::any_of(op.getInnerDimsPos(),
4243  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4244  return false;
4245  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4246  return false;
4247  return !PackOp::requirePaddingValue(
4248  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4249  op.getOuterDimsPerm(), op.getMixedTiles());
4250 }
4251 
4252 /// Returns true if the `srcShape` or `destShape` is different from the one in
4253 /// `packOp` and populates each with the inferred static shape.
4254 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4255  SmallVectorImpl<int64_t> &destShape) {
4256  bool changeNeeded = false;
4257  srcShape.assign(packOp.getSourceType().getShape().begin(),
4258  packOp.getSourceType().getShape().end());
4259  destShape.assign(packOp.getDestType().getShape().begin(),
4260  packOp.getDestType().getShape().end());
4261  llvm::SmallSetVector<int64_t, 4> innerDims;
4262  innerDims.insert(packOp.getInnerDimsPos().begin(),
4263  packOp.getInnerDimsPos().end());
4264  SmallVector<int64_t> inverseOuterDimsPerm;
4265  if (!packOp.getOuterDimsPerm().empty())
4266  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4267  int srcRank = packOp.getSourceRank();
4268  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4269  if (innerDims.contains(i))
4270  continue;
4271  int64_t srcPos = i;
4272  int64_t destPos = i;
4273  if (!inverseOuterDimsPerm.empty())
4274  destPos = inverseOuterDimsPerm[srcPos];
4275  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4276  ShapedType::isDynamic(destShape[destPos])) {
4277  continue;
4278  }
4279  int64_t size = srcShape[srcPos];
4280  if (ShapedType::isDynamic(size))
4281  size = destShape[destPos];
4282  srcShape[srcPos] = size;
4283  destShape[destPos] = size;
4284  changeNeeded = true;
4285  }
4286  return changeNeeded;
4287 }
4288 
4289 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4290  // Fold a pack(unpack(x)) to x.
4291  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4292  if (unPackOp.getSourceType() != packOp.getDestType())
4293  return failure();
4294  if (packOp.getPaddingValue() ||
4295  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4296  !haveSameTiles(packOp, unPackOp))
4297  return failure();
4298  rewriter.replaceOp(packOp, unPackOp.getSource());
4299  return success();
4300  }
4301 
4302  // Fold optional PaddingValue operand away if padding is not needed.
4303  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4304  rewriter.startOpModification(packOp);
4305  packOp.getPaddingValueMutable().clear();
4306  rewriter.finalizeOpModification(packOp);
4307  return success();
4308  }
4309 
4310  // Insert tensor.cast ops if static shape inference is available.
4311  SmallVector<int64_t> srcShape, destShape;
4312  if (inferStaticShape(packOp, srcShape, destShape)) {
4313  Location loc = packOp.getLoc();
4314  Value source = packOp.getSource();
4315  if (srcShape != packOp.getSourceType().getShape()) {
4316  auto newSrcType = packOp.getSourceType().clone(srcShape);
4317  source =
4318  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4319  }
4320  Value dest = packOp.getDest();
4321  if (destShape != packOp.getDestType().getShape()) {
4322  auto newDestType = packOp.getDestType().clone(destShape);
4323  dest =
4324  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4325  }
4326  Value newOp = rewriter.create<tensor::PackOp>(
4327  loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4328  packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4329  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4330  packOp, packOp.getResult().getType(), newOp);
4331  return success();
4332  }
4333 
4334  return failure();
4335 }
4336 
4337 template <typename PackOrUnpackOp>
4338 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4339  RankedTensorType packedTensorType) {
4340  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4341  std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4342  "Function meant for pack/unpack");
4343  // This is a pad if packing only adds ones and we don't transpose dimensions.
4344 
4345  // Check that we are not transposing any dimensions.
4346  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4347  int64_t numPackedDims = innerDimsPos.size();
4348  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4349  if (orderedDims != innerDimsPos) {
4350  // Dimensions are not in order.
4351  return false;
4352  }
4353 
4354  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4355  int64_t packedRank = packedTensorType.getRank();
4356  // At this point we know that we are taking numPackedDims outer
4357  // dimensions and pushing them all the way in as the innermost dimensions.
4358  // What's left on the outermost dimensions is, in this order:
4359  // - the factor of the packed dimensions, then
4360  // - the untouched dimensions.
4361  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4362  // if all the dimensions that bubble outward are ones.
4363  // Therefore check that all the dimensions but the numPackedDims innermost
4364  // ones are ones.
4365  return llvm::all_of(
4366  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4367  [&packedShape](int64_t i) { return packedShape[i] == 1; });
4368 }
4369 
4370 bool PackOp::isLikePad() {
4371  auto packedTensorType =
4372  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4373  return isLikePadUnPad(*this, packedTensorType);
4374 }
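// Illustrative example (not from the source): a pack is "like a pad" when the
// tiled dimensions are the leading dimensions in order and every remaining
// outer dimension of the packed result is 1, e.g.:
//
//   %p = tensor.pack %src padding_value(%cst : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [128, 256]
//       into %dest : tensor<100x200xf32> -> tensor<1x1x128x256xf32>
//
// which only pads %src up to 128x256 and adds unit dimensions, without
// transposing any data.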
4375 
4376 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4377  std::optional<Attribute> paddingValue;
4378  if (auto pad = adaptor.getPaddingValue())
4379  paddingValue = pad;
4380  if (OpFoldResult reshapedSource = reshapeConstantSource(
4381  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4382  getDestType(), paddingValue))
4383  return reshapedSource;
4384  return {};
4385 }
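// Illustrative sketch (not from the source): the fold above removes a pack
// whose source is a splat constant, assuming the destination shape is static
// and any padding value matches the splat (see reshapeConstantSource):
//
//   %cst = arith.constant dense<0.0> : tensor<128x256xf32>
//   %p = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %d : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//
// folds to a dense<0.0> : tensor<16x8x8x32xf32> constant.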
4386 
4387 //===----------------------------------------------------------------------===//
4388 // UnPackOp
4389 //===----------------------------------------------------------------------===//
4390 
4391 void UnPackOp::getAsmResultNames(
4392  function_ref<void(Value, StringRef)> setNameFn) {
4393  setNameFn(getResult(), "unpack");
4394 }
4395 
4396 LogicalResult
4397 UnPackOp::reifyResultShapes(OpBuilder &builder,
4398  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4399  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4400 }
4401 
4402 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4403  return getDimAndTileMappingImpl(*this);
4404 }
4405 
4406 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4407  return getMixedTilesImpl(*this);
4408 }
4409 
4410 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4411  return getStaticTilesImpl(*this);
4412 }
4413 
4414 LogicalResult UnPackOp::verify() {
4415  return commonVerifierPackAndUnPackOp(*this);
4416 }
4417 
4418 Speculation::Speculatability UnPackOp::getSpeculatability() {
4419  // See PackOp::getSpeculatability.
4420  if (!areTilesAndTiledDimsAllConstant(*this))
4421  return Speculation::NotSpeculatable;
4422 
4423  return Speculation::Speculatable;
4424 }
4425 
4426 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4427  Value dest, ArrayRef<int64_t> innerDimsPos,
4428  ArrayRef<OpFoldResult> innerTiles,
4429  ArrayRef<int64_t> outerDimsPerm) {
4430  assert(innerDimsPos.size() == innerTiles.size() &&
4431  "number of tile sizes specified must match the specified number of "
4432  "original dimensions to be tiled");
4433  SmallVector<int64_t> staticTileSizes;
4434  SmallVector<Value> dynamicTileSizes;
4435  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4436  build(builder, state, dest.getType(), source, dest,
4437  outerDimsPerm.empty() ? nullptr
4438  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4439  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4440  builder.getDenseI64ArrayAttr(staticTileSizes));
4441 }
4442 
4443 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4444  Value source,
4445  ArrayRef<OpFoldResult> innerTileSizes,
4446  ArrayRef<int64_t> innerDimsPos,
4447  ArrayRef<int64_t> outerDimsPerm) {
4448  AffineExpr sym0, sym1;
4449  bindSymbols(b.getContext(), sym0, sym1);
4450  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4451  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4452  };
4453 
4454  SmallVector<OpFoldResult> mixedSizes;
4455  auto srcType = llvm::cast<RankedTensorType>(source.getType());
4456  for (auto i :
4457  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4458  if (srcType.isDynamicDim(i))
4459  mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4460  else
4461  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4462  }
4463  if (!outerDimsPerm.empty()) {
4464  applyPermutationToVector<OpFoldResult>(
4465  mixedSizes, invertPermutationVector(outerDimsPerm));
4466  }
4467 
4468  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4469  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4470 
4471  auto elemType = srcType.getElementType();
4472  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4473 }
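// Illustrative example (not from the source): for a packed source of type
// tensor<16x8x8x32xf32> with inner_dims_pos = [0, 1] and inner_tiles = [8, 32],
// the helper above keeps the two outer sizes {16, 8}, multiplies them by the
// tiles, and creates an empty destination of type tensor<128x256xf32>.
// Dynamic outer sizes are read with tensor.dim and multiplied via a folded
// affine apply instead of constant arithmetic.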
4474 
4475 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4476  Value transposedSource,
4477  ArrayRef<int64_t> innerPermutation,
4478  ArrayRef<int64_t> outerPermutation) {
4479  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4480  *this, innerPermutation, outerPermutation);
4481  return b.create<UnPackOp>(loc, transposedSource, getDest(),
4482  metadata.innerDimsPos, metadata.innerTiles,
4483  metadata.outerDimsPerm);
4484 }
4485 
4486 /// Returns true if the `srcShape` or `destShape` is different from the one in
4487 /// `op` and populates each with the inferred static shape.
4488 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4489  SmallVectorImpl<int64_t> &destShape) {
4490  bool changeNeeded = false;
4491  srcShape.assign(op.getSourceType().getShape().begin(),
4492  op.getSourceType().getShape().end());
4493  destShape.assign(op.getDestType().getShape().begin(),
4494  op.getDestType().getShape().end());
4495  llvm::SmallSetVector<int64_t, 4> innerDims;
4496  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4497  SmallVector<int64_t> inverseOuterDimsPerm;
4498  if (!op.getOuterDimsPerm().empty())
4499  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4500  int destRank = op.getDestRank();
4501  for (auto i : llvm::seq<int64_t>(0, destRank)) {
4502  if (innerDims.contains(i))
4503  continue;
4504  int64_t srcPos = i;
4505  int64_t destPos = i;
4506  if (!inverseOuterDimsPerm.empty())
4507  srcPos = inverseOuterDimsPerm[destPos];
4508  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4509  ShapedType::isDynamic(destShape[destPos])) {
4510  continue;
4511  }
4512  int64_t size = srcShape[srcPos];
4513  if (ShapedType::isDynamic(size))
4514  size = destShape[destPos];
4515  srcShape[srcPos] = size;
4516  destShape[destPos] = size;
4517  changeNeeded = true;
4518  }
4519  return changeNeeded;
4520 }
4521 
4522 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4523  PatternRewriter &rewriter) {
4524  /// unpack(pack(x)) -> x
4525  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4526  if (packOp.getSourceType() != unPackOp.getDestType())
4527  return failure();
4528  if (packOp.getPaddingValue() ||
4529  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4530  !haveSameTiles(packOp, unPackOp))
4531  return failure();
4532  rewriter.replaceOp(unPackOp, packOp.getSource());
4533  return success();
4534  }
4535  /// unpack(destinationStyleOp(x)) -> unpack(x)
4536  if (auto dstStyleOp =
4537  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4538  auto destValue = cast<OpResult>(unPackOp.getDest());
4539  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4540  rewriter.modifyOpInPlace(unPackOp,
4541  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4542  return success();
4543  }
4544 
4545  // Insert tensor.cast ops if static shape inference is available.
4546  SmallVector<int64_t> srcShape, destShape;
4547  if (inferStaticShape(unPackOp, srcShape, destShape)) {
4548  Location loc = unPackOp.getLoc();
4549  Value source = unPackOp.getSource();
4550  if (srcShape != unPackOp.getSourceType().getShape()) {
4551  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4552  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4553  unPackOp.getSource());
4554  }
4555  Value dest = unPackOp.getDest();
4556  if (destShape != unPackOp.getDestType().getShape()) {
4557  auto newDestType = unPackOp.getDestType().clone(destShape);
4558  dest =
4559  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4560  }
4561  Value newOp = rewriter.create<tensor::UnPackOp>(
4562  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4563  unPackOp.getOuterDimsPerm());
4564  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4565  unPackOp, unPackOp.getResult().getType(), newOp);
4566  return success();
4567  }
4568 
4569  return failure();
4570 }
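// Illustrative example (not from the source) of the unpack(pack(x)) fold
// above, assuming matching attributes and tiles and no padding value:
//
//   %p = tensor.pack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %d : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//   %u = tensor.unpack %p inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %e : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
//
// canonicalizes so that all uses of %u are replaced by %x. The second pattern
// above similarly rewrites the unpack to use the init operand of a
// destination-style producer as its destination, dropping the dependence on
// that producer's result.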
4571 
4572 bool UnPackOp::isLikeUnPad() {
4573  RankedTensorType packedTensorType = getSourceType();
4574  return isLikePadUnPad(*this, packedTensorType);
4575 }
4576 
4577 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4578  if (OpFoldResult reshapedSource = reshapeConstantSource(
4579  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4580  getResult().getType()))
4581  return reshapedSource;
4582  return {};
4583 }
4584 
4585 //===----------------------------------------------------------------------===//
4586 // Common Canonicalizers and Folders.
4587 //===----------------------------------------------------------------------===//
4588 
4589 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4590 /// the `tensor.cast` has a source that is more static than the consuming op.
4591 ///
4592 /// Example:
4593 /// ```mlir
4594 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4595 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4596 /// ```
4597 ///
4598 /// folds into:
4599 ///
4600 /// ```mlir
4601 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4602 /// ```
4603 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4604 /// can add the pattern to their canonicalizers.
4605 struct FoldTensorCastProducerOp
4606  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4607  using OpInterfaceRewritePattern<
4608  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4609 
4610  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4611  PatternRewriter &rewriter) const override {
4612  // InsertSliceOp has its own logic about folding tensor.cast ops.
4613  if (isa<InsertSliceOp>(op.getOperation()))
4614  return failure();
4615 
4616  // Exclude DPS ops that are also LoopLike from this interface as they
4617  // might need special handling of attached regions.
4618  if (isa<LoopLikeOpInterface>(op.getOperation()))
4619  return failure();
4620 
4621  // If no operand comes from a tensor::CastOp that can be folded, then fail.
4622  bool hasTensorCastOperand =
4623  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4624  if (llvm::isa<BlockArgument>(opOperand.get()))
4625  return false;
4626  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4627  return castOp && canFoldIntoConsumerOp(castOp);
4628  });
4629  if (!hasTensorCastOperand)
4630  return failure();
4631 
4632  SmallVector<Type, 4> newResultTypes(op->getResultTypes());
4633  SmallVector<Value, 4> newOperands;
4634  newOperands.reserve(op->getNumOperands());
4635  // Assumes that the result has dpsInits followed by nonDpsInits.
4636  int64_t dpsInitIdx = 0;
4637  for (OpOperand &opOperand : op->getOpOperands()) {
4638  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4639  bool fold = canFoldIntoConsumerOp(tensorCastOp);
4640  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4641  if (op.isDpsInit(&opOperand) &&
4642  !llvm::isa<MemRefType>(newOperands.back().getType()))
4643  newResultTypes[dpsInitIdx++] = newOperands.back().getType();
4644  }
4645 
4646  // Clone op.
4647  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
4648  SmallVector<Value, 4> replacements;
4649  replacements.reserve(newOp->getNumResults());
4650  for (auto [oldResult, newResult] :
4651  llvm::zip(op->getResults(), newOp->getResults())) {
4652  if (newResult.getType() != oldResult.getType()) {
4653  replacements.push_back(rewriter.create<tensor::CastOp>(
4654  op->getLoc(), oldResult.getType(), newResult));
4655  } else {
4656  replacements.push_back(newResult);
4657  }
4658  }
4659  rewriter.replaceOp(op, replacements);
4660 
4661  return success();
4662  }
4663 };
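// Note on the pattern above: when folding a cast makes a DPS init (and hence
// a result) more static, the cloned op is created with the refined result
// type and a tensor.cast back to the original type is inserted for existing
// users, e.g. (illustrative):
//
//   %2 = consumer %0 ... : tensor<8x16xf32> ...
//   %3 = tensor.cast %2 : tensor<8x16xf32> to tensor<?x?xf32>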
4664 
4665 //===----------------------------------------------------------------------===//
4666 // TensorDialect
4667 //===----------------------------------------------------------------------===//
4668 
4669 void TensorDialect::getCanonicalizationPatterns(
4670  RewritePatternSet &results) const {
4671  results.add<FoldTensorCastProducerOp>(getContext());
4672 }
4673 
4674 //===----------------------------------------------------------------------===//
4675 // TableGen'd op method definitions
4676 //===----------------------------------------------------------------------===//
4677 
4678 #define GET_OP_CLASSES
4679 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: TensorOps.cpp:4182
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition: TensorOps.cpp:386
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: TensorOps.cpp:3910
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition: TensorOps.cpp:1291
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
Definition: TensorOps.cpp:2171
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: TensorOps.cpp:3730
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: TensorOps.cpp:3761
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: TensorOps.cpp:4240
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2897
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: TensorOps.cpp:4063
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
Definition: TensorOps.cpp:4048
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: TensorOps.cpp:3745
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition: TensorOps.cpp:2592
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition: TensorOps.cpp:2444
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition: TensorOps.cpp:2466
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
Definition: TensorOps.cpp:4254
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
Definition: TensorOps.cpp:3791
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1549
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:4227
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition: TensorOps.cpp:2553
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: TensorOps.cpp:4338
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: TensorOps.cpp:3806
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
Definition: TensorOps.cpp:179
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: TensorOps.cpp:3718
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: TensorOps.cpp:3774
static OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result, std::optional< Attribute > cst=std::nullopt)
Try to remove a tensor operation if it would only reshape a constant.
Definition: TensorOps.cpp:1077
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2893
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:135
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:4213
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition: TensorOps.cpp:2612
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition: TensorOps.cpp:1773
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:954
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:136
UnitAttr getUnitAttr()
Definition: Builders.cpp:126
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:195
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:387
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:140
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:383
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:83
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
This class helps build Operations.
Definition: Builders.h:212
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:403
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:449
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:417
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:258
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:269
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:99
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:59
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:357
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2404
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:348
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:318
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Definition: TensorOps.cpp:2870
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
Definition: TensorOps.cpp:2491
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:123
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:55
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:74
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:154
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:266
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:109
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:522
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:384
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1202
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:91
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
Definition: TensorOps.cpp:4606
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:4610
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2423
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2424
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2411
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2412
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)