TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/Support/LLVM.h"
33 #include "llvm/ADT/DenseSet.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallBitVector.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/MathExtras.h"
39 #include <optional>
40 
41 using namespace mlir;
42 using namespace mlir::tensor;
43 
44 using llvm::divideCeilSigned;
45 using llvm::divideFloorSigned;
46 using llvm::mod;
47 
48 /// Materialize a single constant operation from a given attribute value with
49 /// the desired resultant type.
50 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
51  Attribute value, Type type,
52  Location loc) {
53  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
54  return op;
55  if (complex::ConstantOp::isBuildableWith(value, type))
56  return complex::ConstantOp::create(builder, loc, type,
57  llvm::cast<ArrayAttr>(value));
58  return nullptr;
59 }
60 
61 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
62  int64_t dim) {
63  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
64  if (tensorType.isDynamicDim(dim))
65  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
66 
67  return builder.getIndexAttr(tensorType.getDimSize(dim));
68 }
69 
70 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
71  Location loc, Value value) {
72  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
73  SmallVector<OpFoldResult> result;
74  for (int64_t i = 0; i < tensorType.getRank(); ++i)
75  result.push_back(getMixedSize(builder, loc, value, i));
76  return result;
77 }
78 
79 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
80  OpResult opResult) {
81  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
82  assert(tensorType && "expected tensor type");
83 
84  // If the op has a destination, it implements DestinationStyleOpInterface and
85  // we can query the destination operand from that interface.
86  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
87  if (destOp)
88  return destOp.getTiedOpOperand(opResult)->get();
89 
90  // Otherwise, create a new destination tensor with the same shape.
91  OpBuilder::InsertionGuard g(b);
92  b.setInsertionPoint(opResult.getDefiningOp());
93 
94  // Compute sizes.
95  SmallVector<OpFoldResult> mixedSizes;
96  if (!tensorType.hasStaticShape()) {
97  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
98  ReifiedRankedShapedTypeDims reifiedShapes;
99  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
100  return failure();
101  mixedSizes = reifiedShapes[opResult.getResultNumber()];
102  } else {
103  // Static shape: Take static sizes directly.
104  for (int64_t sz : tensorType.getShape())
105  mixedSizes.push_back(b.getIndexAttr(sz));
106  }
107 
108  // Create empty tensor.
109  Value emptyTensor =
110  tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
111  return emptyTensor;
112 }
113 
114 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
115  Operation *op,
116  SmallVector<Value> &result) {
117  for (OpResult opResult : op->getResults()) {
118  if (llvm::isa<TensorType>(opResult.getType())) {
119  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
120  if (failed(destination))
121  return failure();
122  result.push_back(*destination);
123  }
124  }
125  return success();
126 }
127 
128 bool mlir::tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
129  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
130  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
131  return rtp1.getShape() == rtp2.getShape() &&
132  rtp1.getElementType() == rtp2.getElementType();
133  return false;
134  }
135  return tp1 == tp2; // default implementation
136 }
137 
138 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
139 /// rank-extending tensor.insert_slice op.
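/// For illustration (not from the original source), a small worked example of
/// this helper with hypothetical inputs:
///   reducedShape = [4, 8], mixedSizes = [4, 1, 8]
/// The static unit size at index 1 has no counterpart in the reduced shape, so
/// the returned bit vector has only bit 1 set.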
140 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
141  ArrayRef<OpFoldResult> mixedSizes) {
142  llvm::SmallBitVector droppedDims(mixedSizes.size());
143  int64_t shapePos = reducedShape.size() - 1;
144 
145  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
146  size_t idx = mixedSizes.size() - size.index() - 1;
147  // Rank-reduced dims must have a static unit dimension.
148  bool isStaticUnitSize =
149  isa<Attribute>(size.value()) &&
150  llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
151 
152  if (shapePos < 0) {
153  // There are no more dims in the reduced shape. All remaining sizes must
154  // be rank-reduced dims.
155  assert(isStaticUnitSize && "expected unit dim");
156  droppedDims.set(idx);
157  continue;
158  }
159 
160  // Dim is preserved if the size is not a static 1.
161  if (!isStaticUnitSize) {
162  --shapePos;
163  continue;
164  }
165 
166  // Dim is preserved if the reduced shape dim is also 1.
167  if (reducedShape[shapePos] == 1) {
168  --shapePos;
169  continue;
170  }
171 
172  // Otherwise: Dim is dropped.
173  droppedDims.set(idx);
174  }
175 
176  assert(shapePos < 0 && "dimension mismatch");
177  return droppedDims;
178 }
179 
180 /// Given a ranked tensor type and a range of values that defines its dynamic
181 /// dimension sizes, turn all dynamic sizes that have a constant value into
182 /// static dimension sizes.
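/// For illustration (not from the original source): assuming a hypothetical
/// tensor<?x?xf32> whose dynamic sizes are (%c5, %d) with %c5 a constant 5,
/// the folded type becomes tensor<5x?xf32> and foldedDynamicSizes holds only
/// %d, since the constant size was promoted to a static dimension.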
183 static RankedTensorType
184 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
185  SmallVector<Value> &foldedDynamicSizes) {
186  SmallVector<int64_t> staticShape(type.getShape());
187  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
188  "incorrect number of dynamic sizes");
189 
190  // Compute new static and dynamic sizes.
191  unsigned ctr = 0;
192  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
193  if (type.isDynamicDim(i)) {
194  Value dynamicSize = dynamicSizes[ctr++];
195  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
196  if (cst.has_value()) {
197  // Dynamic size must be non-negative.
198  if (cst.value() < 0) {
199  foldedDynamicSizes.push_back(dynamicSize);
200  continue;
201  }
202  staticShape[i] = *cst;
203  } else {
204  foldedDynamicSizes.push_back(dynamicSize);
205  }
206  }
207  }
208 
209  return RankedTensorType::get(staticShape, type.getElementType(),
210  type.getEncoding());
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // BitcastOp
215 //===----------------------------------------------------------------------===//
216 
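/// Illustrative note (not from the original source): bitcast compatibility
/// only requires matching element bit widths and compatible shapes. For
/// example, a hypothetical tensor<4xi32> <-> tensor<4xf32> pair is compatible,
/// while tensor<4xi32> <-> tensor<4xi64> is not.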
217 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
218  if (inputs.size() != 1 || outputs.size() != 1)
219  return false;
220  Type a = inputs.front(), b = outputs.front();
221  auto aT = dyn_cast<TensorType>(a);
222  auto bT = dyn_cast<TensorType>(b);
223  if (!aT || !bT)
224  return false;
225 
226  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
227  return false;
228 
229  return succeeded(verifyCompatibleShape(aT, bT));
230 }
231 
232 namespace {
233 
234 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
235 /// operation.
236 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
237  using OpRewritePattern<BitcastOp>::OpRewritePattern;
238 
239  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
240  PatternRewriter &rewriter) const final {
241  auto tensorBitcastOperand =
242  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
243  if (!tensorBitcastOperand)
244  return failure();
245 
246  auto resultType = cast<TensorType>(tensorBitcast.getType());
247  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
248  tensorBitcastOperand.getOperand());
249  return success();
250  }
251 };
252 
253 } // namespace
254 
255 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
256  MLIRContext *context) {
257  results.add<ChainedTensorBitcast>(context);
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // CastOp
262 //===----------------------------------------------------------------------===//
263 
264 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
265  setNameFn(getResult(), "cast");
266 }
267 
268 /// Returns true if `target` is a ranked tensor type that preserves static
269 /// information available in the `source` ranked tensor type.
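/// Illustrative examples (not from the original source), using hypothetical
/// types:
///   preservesStaticInformation(tensor<?x8xf32>, tensor<4x8xf32>) == true
///     (the target keeps the static 8 and only refines the dynamic dim)
///   preservesStaticInformation(tensor<4x8xf32>, tensor<?x8xf32>) == false
///     (the target drops the static 4)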
270 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
271  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
272  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
273 
274  // Requires RankedTensorType.
275  if (!sourceType || !targetType)
276  return false;
277 
278  // Requires same elemental type.
279  if (sourceType.getElementType() != targetType.getElementType())
280  return false;
281 
282  // Requires same rank.
283  if (sourceType.getRank() != targetType.getRank())
284  return false;
285 
286  // Requires same encoding.
287  if (sourceType.getEncoding() != targetType.getEncoding())
288  return false;
289 
290  // If cast is towards more static sizes along any dimension, don't fold.
291  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
292  if (ShapedType::isStatic(std::get<0>(t)) &&
293  ShapedType::isDynamic(std::get<1>(t)))
294  return false;
295  }
296 
297  return true;
298 }
299 
300 /// Determines whether tensor::CastOp casts to a more dynamic version of the
301 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
302 /// implement canonicalization patterns for ops in different dialects that may
303 /// consume the results of tensor.cast operations. Such foldable tensor.cast
304 /// operations are typically inserted as `slice` ops and are canonicalized,
305 /// to preserve the type compatibility of their uses.
306 ///
307 /// Returns true when all conditions are met:
308 /// 1. source and result are ranked tensors with same element type and rank.
309 /// 2. the source of the cast has more static information than its result.
310 ///
311 /// Example:
312 /// ```mlir
313 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
314 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
315 /// ```
316 ///
317 /// folds into:
318 ///
319 /// ```mlir
320 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
321 /// ```
322 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
323  if (!castOp)
324  return false;
325 
326  // Can fold if the source of cast has at least as much static information as
327  // its results.
328  return preservesStaticInformation(castOp.getType(),
329  castOp.getSource().getType());
330 }
331 
332 /// Determines whether the tensor::CastOp casts to a more static version of the
333 /// source tensor. This is useful to fold into a producing op and implement
334 /// canonicalization patterns with the `tensor.cast` op as the root, but
335 /// producer being from different dialects. Returns true when all conditions are
336 /// met:
337 /// 1. source and result are ranked tensors with same element type and rank.
338 /// 2. the result type has more static information than the source.
339 ///
340 /// Example:
341 /// ```mlir
342 /// %1 = producer ... : tensor<?x?xf32>
343 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
344 /// ```
345 ///
346 /// can be canonicalized to :
347 ///
348 /// ```mlir
349 /// %2 = producer ... : tensor<8x16xf32>
350 /// ```
351 /// Not all ops might be canonicalizable this way, but for those that can be,
352 /// this method provides a check that it is worth doing the canonicalization.
353 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
354  if (!castOp)
355  return false;
356  return preservesStaticInformation(castOp.getSource().getType(),
357  castOp.getType());
358 }
359 
360 bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
361  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
362  if (llvm::isa<BlockArgument>(opOperand.get()))
363  return false;
364  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
365  return castOp && canFoldIntoConsumerOp(castOp);
366  });
367 }
368 
369 SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
370  DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
371  SmallVector<Value> newOperands;
372  newOperands.reserve(op->getNumOperands());
373 
374  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
375 
376  // Assumes that the result has dpsInits followed by nonDpsInits.
377  int64_t dpsInitIdx = 0;
378  for (OpOperand &opOperand : op->getOpOperands()) {
379  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
380  bool fold = canFoldIntoConsumerOp(tensorCastOp);
381  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
382  if (op.isDpsInit(&opOperand) &&
383  !llvm::isa<MemRefType>(newOperands.back().getType()))
384  newResTy[dpsInitIdx++] = newOperands.back().getType();
385  }
386  return newOperands;
387 }
388 
389 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
390 /// that can be folded.
391 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
392  bool folded = false;
393  for (OpOperand &operand : op->getOpOperands()) {
394  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
395  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
396  operand.set(castOp.getOperand());
397  folded = true;
398  }
399  }
400  return success(folded);
401 }
402 
403 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
404  if (inputs.size() != 1 || outputs.size() != 1)
405  return false;
406  Type a = inputs.front(), b = outputs.front();
407  auto aT = llvm::dyn_cast<TensorType>(a);
408  auto bT = llvm::dyn_cast<TensorType>(b);
409  if (!aT || !bT)
410  return false;
411 
412  if (aT.getElementType() != bT.getElementType())
413  return false;
414 
415  return succeeded(verifyCompatibleShape(aT, bT));
416 }
417 
418 /// Compute a TensorType that has the joined shape knowledge of the two
419 /// given TensorTypes. The element types need to match.
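/// Illustrative example (not from the original source), with hypothetical
/// types: joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) yields tensor<4x8xf32>,
/// while joinShapes(tensor<4x8xf32>, tensor<5x8xf32>) yields a null type
/// because the static sizes conflict.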
420 static TensorType joinShapes(TensorType one, TensorType two) {
421  assert(one.getElementType() == two.getElementType());
422 
423  if (!one.hasRank())
424  return two;
425  if (!two.hasRank())
426  return one;
427 
428  int64_t rank = one.getRank();
429  if (rank != two.getRank())
430  return {};
431 
432  SmallVector<int64_t, 4> join;
433  join.reserve(rank);
434  for (int64_t i = 0; i < rank; ++i) {
435  if (one.isDynamicDim(i)) {
436  join.push_back(two.getDimSize(i));
437  continue;
438  }
439  if (two.isDynamicDim(i)) {
440  join.push_back(one.getDimSize(i));
441  continue;
442  }
443  if (one.getDimSize(i) != two.getDimSize(i))
444  return {};
445  join.push_back(one.getDimSize(i));
446  }
447  return RankedTensorType::get(join, one.getElementType());
448 }
449 
450 namespace {
451 
452 /// Replaces chains of two tensor.cast operations by a single tensor.cast
453 /// operation if doing so does not remove runtime constraints.
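/// Illustrative sketch (not from the original source) of a chain that can be
/// collapsed, using hypothetical values:
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// folds to `%2 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>`, since
/// dropping the intermediate cast loses no runtime check.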
454 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
455  using OpRewritePattern<CastOp>::OpRewritePattern;
456 
457  LogicalResult matchAndRewrite(CastOp tensorCast,
458  PatternRewriter &rewriter) const final {
459  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
460 
461  if (!tensorCastOperand)
462  return failure();
463 
464  auto sourceType =
465  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
466  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
467  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
468 
469  // We can remove the intermediate cast if joining all three produces the
470  // same result as just joining the source and result shapes.
471  auto firstJoin =
472  joinShapes(joinShapes(sourceType, intermediateType), resultType);
473 
474  // The join might not exist if the cast sequence would fail at runtime.
475  if (!firstJoin)
476  return failure();
477 
478  // The newJoin always exists if the above join exists, it might just contain
479  // less information. If so, we cannot drop the intermediate cast, as doing
480  // so would remove runtime checks.
481  auto newJoin = joinShapes(sourceType, resultType);
482  if (firstJoin != newJoin)
483  return failure();
484 
485  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
486  tensorCastOperand.getOperand());
487  return success();
488  }
489 };
490 
491 /// Fold tensor.cast into tensor.extract_slice producer.
492 /// Example:
493 /// ```
494 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
495 /// tensor<128x512xf32> to tensor<?x512xf32>
496 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
497 /// ```
498 /// ->
499 /// ```
500 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
501 /// tensor<128x512xf32> to tensor<16x512xf32>
502 /// ```
503 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
504  using OpRewritePattern<CastOp>::OpRewritePattern;
505 
506  LogicalResult matchAndRewrite(CastOp tensorCast,
507  PatternRewriter &rewriter) const final {
508  auto extractOperand =
509  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
510 
511  // Cannot fold cast to unranked tensor.
512  auto rankedResultType =
513  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
514  if (!rankedResultType)
515  return failure();
516 
517  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
518  rankedResultType.getShape() ==
519  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
520  .getShape())
521  return failure();
522 
523  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
524  auto dimMask = computeRankReductionMask(
525  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
526  size_t dimIndex = 0;
527  for (size_t i = 0, e = sizes.size(); i < e; i++) {
528  if (dimMask && dimMask->count(i))
529  continue;
530  int64_t dim = rankedResultType.getShape()[dimIndex++];
531  if (ShapedType::isDynamic(dim))
532  continue;
533  sizes[i] = rewriter.getIndexAttr(dim);
534  }
535 
536  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
537  tensorCast, rankedResultType, extractOperand.getSource(),
538  extractOperand.getMixedOffsets(), sizes,
539  extractOperand.getMixedStrides());
540  return success();
541  }
542 };
543 
544 } // namespace
545 
546 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
547  MLIRContext *context) {
548  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // ConcatOp
553 //===----------------------------------------------------------------------===//
554 
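/// Illustrative note (not from the original source): for hypothetical inputs
/// tensor<3x?xf32> and tensor<?x4xf32> concatenated along dim 0, the inferred
/// result type is tensor<?x4xf32>; the non-concatenated dim joins ? with 4 to
/// 4, and the concatenated dim 3 + ? stays dynamic.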
555 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
556  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
557  auto tensorTypes =
558  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
559  return llvm::cast<RankedTensorType>(type);
560  }));
561  int64_t concatRank = tensorTypes[0].getRank();
562 
563  // The concatenation dim must be in the range [0, rank).
564  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
565 
566  SmallVector<int64_t> sizes(concatRank);
567  for (int64_t i = 0, e = concatRank; i < e; ++i) {
568  if (i == dim)
569  continue;
570  SaturatedInteger size;
571  for (auto tensorType : tensorTypes)
572  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
573  sizes[i] = size.asInteger();
574  }
575  auto concatSize = SaturatedInteger::wrap(0);
576  for (auto tensorType : tensorTypes)
577  concatSize =
578  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
579  sizes[dim] = concatSize.asInteger();
580  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
581 }
582 
583 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
584  ValueRange inputs) {
585  FailureOr<RankedTensorType> resultType =
586  inferResultType(dim, inputs.getTypes());
587  assert(succeeded(resultType) && "failed to infer concatenation result type");
588  build(builder, result, *resultType, dim, inputs);
589 }
590 
591 LogicalResult ConcatOp::verify() {
592  if (getInputs().size() < 1)
593  return emitOpError("requires at least one input");
594 
595  SmallVector<RankedTensorType> inputTypes;
596  for (auto input : getInputs())
597  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
598 
599  RankedTensorType resultType = getResultType();
600  int64_t resultRank = getRank();
601  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
602  return type.getRank() != resultRank;
603  }))
604  return emitOpError("rank of concatenated inputs must match result rank");
605 
606  Type resultElementType = resultType.getElementType();
607  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
608  return type.getElementType() != resultElementType;
609  }))
610  return emitOpError("inputs and result element type must match");
611 
612  int64_t dim = getDim();
613  if (dim >= resultRank)
614  return emitOpError("concatenation dim must be less than the tensor rank");
615 
616  SmallVector<int64_t> sizes(resultRank);
617  for (int64_t i = 0, e = resultRank; i < e; ++i) {
618  if (i == dim)
619  continue;
620  SaturatedInteger size;
621  for (auto tensorType : inputTypes) {
622  FailureOr<SaturatedInteger> maybeSize =
623  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
624  if (failed(maybeSize))
625  return emitOpError("static concatenation size mismatch along ")
626  << "non-concatenated dimension " << i;
627  size = *maybeSize;
628  }
629  sizes[i] = size.asInteger();
630  }
631  auto concatSize = SaturatedInteger::wrap(0);
632  for (auto tensorType : inputTypes)
633  concatSize =
634  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
635  sizes[dim] = concatSize.asInteger();
636  auto inferredResultType =
637  RankedTensorType::get(sizes, inputTypes[0].getElementType());
638 
639  for (auto [inferredSize, actualSize] :
640  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
641  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
642  ShapedType::isDynamic(actualSize);
643  if (!hasDynamic && inferredSize != actualSize)
644  return emitOpError("result type ")
645  << resultType << " does not match inferred shape "
646  << inferredResultType << " static sizes";
647  }
648 
649  return success();
650 }
651 
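/// Illustrative sketch (not from the original source) of the decomposition
/// produced below, using hypothetical static shapes:
///
/// ```mlir
/// %r = tensor.concat dim(0) %a, %b
///     : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
/// ```
///
/// becomes, roughly:
///
/// ```mlir
/// %empty = tensor.empty() : tensor<5x4xf32>
/// %s0 = tensor.insert_slice %a into %empty[0, 0] [2, 4] [1, 1]
///     : tensor<2x4xf32> into tensor<5x4xf32>
/// %r = tensor.insert_slice %b into %s0[2, 0] [3, 4] [1, 1]
///     : tensor<3x4xf32> into tensor<5x4xf32>
/// ```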
652 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
653  size_t numInputs = getInputs().size();
654  uint64_t concatDim = getDim();
655 
656  SmallVector<SmallVector<OpFoldResult>> inputShapes;
657  inputShapes.reserve(numInputs);
658  SmallVector<OpFoldResult> concatOffsets;
659  concatOffsets.reserve(numInputs);
660  SmallVector<OpFoldResult> outputShape;
661 
662  AffineExpr addExpr =
663  builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
664  OpFoldResult zero = builder.getIndexAttr(0);
665  Location loc = getLoc();
666  for (auto [index, input] : llvm::enumerate(getInputs())) {
667  SmallVector<OpFoldResult> inputShape =
668  tensor::getMixedSizes(builder, input.getLoc(), input);
669  if (index == 0) {
670  outputShape = inputShape;
671  concatOffsets.push_back(zero);
672  } else {
673  concatOffsets.push_back(outputShape[concatDim]);
674  outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
675  builder, loc, addExpr,
676  {outputShape[concatDim], inputShape[concatDim]});
677  }
678  inputShapes.emplace_back(std::move(inputShape));
679  }
680 
681  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
682  getType().getElementType());
683 
684  int64_t rank = getType().getRank();
685  OpFoldResult one = builder.getIndexAttr(1);
686  SmallVector<OpFoldResult> strides(rank, one);
687  SmallVector<OpFoldResult> offsets(rank, zero);
688  for (auto [index, input] : llvm::enumerate(getInputs())) {
689  offsets[concatDim] = concatOffsets[index];
690  auto insertSlice = tensor::InsertSliceOp::create(
691  builder, loc, input, replacement, offsets, inputShapes[index], strides);
692  replacement = insertSlice.getResult();
693  }
694  if (replacement.getType() != getType()) {
695  replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
696  }
697  return SmallVector<Value>{replacement};
698 }
699 
700 LogicalResult
701 ConcatOp::reifyResultShapes(OpBuilder &builder,
702  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
703  ValueRange inputs = getInputs();
704  int64_t dim = getDim();
705  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
706 
707  Value init = inputs[0];
708  int64_t rank = getType().getRank();
709 
710  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
711 
712  // Pre-populate the result sizes with as much static information as possible
713  // from the given result type, as well as the inferred result type, otherwise
714  // use the dim sizes from the first input.
715  for (int64_t i = 0; i < rank; ++i) {
716  if (i == dim)
717  continue;
718  if (!getType().isDynamicDim(i)) {
719  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
720  } else if (!inferredResultType.isDynamicDim(i)) {
721  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
722  builder, getLoc(),
723  builder.getIndexAttr(inferredResultType.getDimSize(i)));
724  } else {
725  reifiedReturnShapes[0][i] =
726  tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
727  }
728  }
729 
730  if (getType().isDynamicDim(dim)) {
731  // Take the sum of the input sizes along the concatenated dim.
732  AffineExpr sum = builder.getAffineDimExpr(0);
733  SmallVector<OpFoldResult> sizes = {
734  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
735  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
736  sum = sum + builder.getAffineDimExpr(idx + 1);
737  sizes.push_back(
738  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
739  }
740  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
741  builder, getLoc(),
742  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
743  } else {
744  // If the result shape is static along the concatenated dim, use the static
745  // shape.
746  reifiedReturnShapes[0][dim] =
747  builder.getIndexAttr(getType().getDimSize(dim));
748  }
749  return success();
750 }
751 
752 void ConcatOp::getAsmResultNames(
753  function_ref<void(Value, StringRef)> setNameFn) {
754  setNameFn(getResult(), "concat");
755 }
756 
757 OpFoldResult ConcatOp::fold(FoldAdaptor) {
758  ValueRange inputs = getInputs();
759  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
760  return inputs[0];
761  return {};
762 }
763 
764 namespace {
765 /// Fold a concat op with a single input to a cast.
766 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
767  using OpRewritePattern<ConcatOp>::OpRewritePattern;
768 
769  LogicalResult matchAndRewrite(ConcatOp concatOp,
770  PatternRewriter &rewriter) const override {
771  if (concatOp.getInputs().size() != 1)
772  return failure();
773  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
774  concatOp.getInputs()[0]);
775  return success();
776  }
777 };
778 
779 /// Propagate static shapes into the operands of a `tensor.concat`.
780 ///
781 /// `tensor.concat` requires every operand to match on all dimensions except the
782 /// concatenation dimension. If one operand is already static in those
783 /// dimensions, the other operands may safely be refined to that same static
784 /// shape.
785 ///
786 /// Example:
787 ///
788 /// ```mlir
789 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
790 /// tensor<?x12xi32>
791 /// ```
792 /// ->
793 /// ```mlir
794 /// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
795 /// %2 = tensor.concat dim(0) %0, %cast :
796 /// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
797 /// ```
798 struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
799  using OpRewritePattern<ConcatOp>::OpRewritePattern;
800 
801  LogicalResult matchAndRewrite(ConcatOp concatOp,
802  PatternRewriter &rewriter) const override {
803  int64_t dim = concatOp.getDim();
804  RankedTensorType inferredResultType =
805  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
806 
807  // Find operands for which a more static shape can be inferred.
808  LogicalResult matched = failure();
809  // Inferred operand shapes are identical in every dimension except the
810  // concatenation dimension.
811  SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
812  for (auto [operandIdx, operandType] :
813  llvm::enumerate(concatOp->getOperandTypes())) {
814  // Compute inferred type for operand.
815  inferredOperandShape[dim] =
816  cast<RankedTensorType>(operandType).getDimSize(dim);
817  auto inferredOperandType = RankedTensorType::get(
818  inferredOperandShape, inferredResultType.getElementType());
819 
820  // Check if inferred type is more static.
821  if (!preservesStaticInformation(inferredOperandType, operandType)) {
822  matched = success();
823 
824  // Use refined operand type and create cast from original operand.
825  auto castOp =
826  CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
827  concatOp.getOperand(operandIdx));
828  rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
829  concatOp->setOperand(operandIdx, castOp->getResult(0));
830  });
831  }
832  }
833 
834  return matched;
835  }
836 };
837 
838 /// Ensure `tensor.concat`'s result type is at least as static as can be inferred
839 /// from its operand types.
840 ///
841 /// Example:
842 /// ```mlir
843 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
844 /// tensor<?x?xi32>
845 /// ```
846 /// ->
847 /// ```mlir
848 /// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
849 /// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
850 /// tensor<?x?xi32>
851 /// ```
852 struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
853  using OpRewritePattern<ConcatOp>::OpRewritePattern;
854 
855  LogicalResult matchAndRewrite(ConcatOp concatOp,
856  PatternRewriter &rewriter) const override {
857  int64_t dim = concatOp.getDim();
858  RankedTensorType inferredResultType =
859  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
860 
861  // The result type should be at least as static as inferred result type.
862  if (preservesStaticInformation(inferredResultType,
863  concatOp.getResultType())) {
864  return failure();
865  }
866 
867  auto newConcatOp =
868  ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
869  concatOp->getOperands());
870  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
871  newConcatOp);
872 
873  return success();
874  }
875 };
876 } // namespace
877 
878 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
879  MLIRContext *context) {
880  results
881  .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
882  context);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // DimOp
887 //===----------------------------------------------------------------------===//
888 
889 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
890  setNameFn(getResult(), "dim");
891 }
892 
893 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
894  int64_t index) {
895  auto loc = result.location;
896  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
897  build(builder, result, source, indexValue);
898 }
899 
900 std::optional<int64_t> DimOp::getConstantIndex() {
901  return getConstantIntValue(getIndex());
902 }
903 
904 Speculation::Speculatability DimOp::getSpeculatability() {
905  auto constantIndex = getConstantIndex();
906  if (!constantIndex)
907  return Speculation::NotSpeculatable;
908 
909  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
910  if (!rankedSourceType)
911  return Speculation::NotSpeculatable;
912 
913  if (rankedSourceType.getRank() <= constantIndex)
914  return Speculation::NotSpeculatable;
915 
916  return Speculation::Speculatable;
917 }
918 
919 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
920  SetIntLatticeFn setResultRange) {
921  setResultRange(getResult(),
922  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
923 }
924 
925 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
926  // All forms of folding require a known index.
927  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
928  if (!index)
929  return {};
930 
931  // Folding for unranked types (UnrankedTensorType) is not supported.
932  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
933  if (!tensorType)
934  return {};
935 
936  // Out of bound indices produce undefined behavior but are still valid IR.
937  // Don't choke on them.
938  int64_t indexVal = index.getInt();
939  if (indexVal < 0 || indexVal >= tensorType.getRank())
940  return {};
941 
942  // Fold if the shape extent along the given index is known.
943  if (!tensorType.isDynamicDim(index.getInt())) {
944  Builder builder(getContext());
945  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
946  }
947 
948  Operation *definingOp = getSource().getDefiningOp();
949 
950  // Fold dim to the operand of tensor.generate.
951  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
952  auto resultType =
953  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
954  // The case where the type encodes the size of the dimension is handled
955  // above.
956  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
957 
958  // Find the operand of the fromElements that corresponds to this index.
959  auto dynExtents = fromElements.getDynamicExtents().begin();
960  for (auto dim : resultType.getShape().take_front(index.getInt()))
961  if (ShapedType::isDynamic(dim))
962  dynExtents++;
963 
964  return Value{*dynExtents};
965  }
966 
967  // The size at the given index is now known to be a dynamic size.
968  unsigned unsignedIndex = index.getValue().getZExtValue();
969 
970  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
971  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
972  // `resolve-shaped-type-result-dims` pass.
973  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
974  sliceOp.isDynamicSize(unsignedIndex)) {
975  return {sliceOp.getDynamicSize(unsignedIndex)};
976  }
977  }
978 
979  // dim(cast) -> dim
980  if (succeeded(foldTensorCast(*this)))
981  return getResult();
982 
983  return {};
984 }
985 
986 namespace {
987 /// Fold dim of a cast into the dim of the source of the tensor cast.
988 struct DimOfCastOp : public OpRewritePattern<DimOp> {
989  using OpRewritePattern<DimOp>::OpRewritePattern;
990 
991  LogicalResult matchAndRewrite(DimOp dimOp,
992  PatternRewriter &rewriter) const override {
993  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
994  if (!castOp)
995  return failure();
996  Value newSource = castOp.getOperand();
997  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
998  return success();
999  }
1000 };
1001 
1002 /// Fold dim of a destination passing style op into the dim of the corresponding
1003 /// init.
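/// Illustrative sketch (not from the original source), with hypothetical IR:
/// `%d = tensor.dim %r, %c0`, where `%r` is a result of a destination-style op
/// tied to init `%init`, is rewritten to `%d = tensor.dim %init, %c0`, since
/// the result and its init share the same shape.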
1004 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1005  using OpRewritePattern<DimOp>::OpRewritePattern;
1006 
1007  LogicalResult matchAndRewrite(DimOp dimOp,
1008  PatternRewriter &rewriter) const override {
1009  auto source = dimOp.getSource();
1010  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1011  if (!destOp)
1012  return failure();
1013 
1014  auto resultIndex = cast<OpResult>(source).getResultNumber();
1015  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1016 
1017  rewriter.modifyOpInPlace(
1018  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1019  return success();
1020  }
1021 };
1022 
1023 /// Fold dim of a tensor reshape operation to an extract into the reshape's shape
1024 /// operand.
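/// Illustrative sketch (not from the original source), with hypothetical IR:
///
/// ```mlir
/// %r = tensor.reshape %src(%shape)
///     : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// %d = tensor.dim %r, %c1 : tensor<?x?xf32>
/// ```
///
/// becomes `%e = tensor.extract %shape[%c1] : tensor<2xindex>` (plus an
/// index_cast if the shape element type is not `index`).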
1025 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1026  using OpRewritePattern<DimOp>::OpRewritePattern;
1027 
1028  LogicalResult matchAndRewrite(DimOp dim,
1029  PatternRewriter &rewriter) const override {
1030  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1031 
1032  if (!reshape)
1033  return failure();
1034 
1035  // Since tensors are immutable we don't need to worry about where to place
1036  // the extract call
1037  rewriter.setInsertionPointAfter(dim);
1038  Location loc = dim.getLoc();
1039  Value extract =
1040  ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1041  if (extract.getType() != dim.getType())
1042  extract =
1043  arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1044  rewriter.replaceOp(dim, extract);
1045  return success();
1046  }
1047 };
1048 } // namespace
1049 
1050 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1051  MLIRContext *context) {
1052  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1053 }
1054 
1055 //===----------------------------------------------------------------------===//
1056 // EmptyOp
1057 //===----------------------------------------------------------------------===//
1058 
1059 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1060  ArrayRef<int64_t> staticShape, Type elementType,
1061  Attribute encoding) {
1062  assert(none_of(staticShape, ShapedType::isDynamic) &&
1063  "expected only static sizes");
1064  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1065 }
1066 
1067 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1068  ArrayRef<int64_t> staticShape, Type elementType,
1069  ValueRange dynamicSizes, Attribute encoding) {
1070  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1071  build(builder, result, tensorType, dynamicSizes);
1072 }
1073 
1074 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1075  ArrayRef<OpFoldResult> sizes, Type elementType,
1076  Attribute encoding) {
1077  SmallVector<int64_t> staticShape;
1078  SmallVector<Value> dynamicSizes;
1079  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1080  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1081 }
1082 
1083 LogicalResult EmptyOp::verify() {
1084  if (getType().getNumDynamicDims() != getDynamicSizes().size())
1085  return emitOpError("incorrect number of dynamic sizes, has ")
1086  << getDynamicSizes().size() << ", expected "
1087  << getType().getNumDynamicDims();
1088  return success();
1089 }
1090 
1091 LogicalResult
1092 EmptyOp::reifyResultShapes(OpBuilder &builder,
1093  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1094  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1095  unsigned ctr = 0;
1096  for (int64_t i = 0; i < getType().getRank(); ++i) {
1097  if (getType().isDynamicDim(i)) {
1098  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1099  } else {
1100  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1101  }
1102  }
1103  return success();
1104 }
1105 
1106 Value EmptyOp::getDynamicSize(unsigned idx) {
1107  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1108  unsigned ctr = 0;
1109  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1110  if (getType().isDynamicDim(i))
1111  ++ctr;
1112  return getDynamicSizes()[ctr];
1113 }
1114 
1115 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1116  SmallVector<OpFoldResult> result;
1117  unsigned ctr = 0;
1118  OpBuilder b(getContext());
1119  for (int64_t i = 0; i < getType().getRank(); ++i) {
1120  if (getType().isDynamicDim(i)) {
1121  result.push_back(getDynamicSizes()[ctr++]);
1122  } else {
1123  result.push_back(b.getIndexAttr(getType().getShape()[i]));
1124  }
1125  }
1126  return result;
1127 }
1128 
1129 namespace {
1130 /// Change the type of the result of a `tensor.empty` by making the result
1131 /// type statically sized along dimensions that in the original operation were
1132 /// defined as dynamic, but the size was defined using a `constant` op. For
1133 /// example
1134 ///
1135 /// %c5 = arith.constant 5: index
1136 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1137 ///
1138 /// to
1139 ///
1140 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1141 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1142  using OpRewritePattern<EmptyOp>::OpRewritePattern;
1143 
1144  LogicalResult matchAndRewrite(EmptyOp op,
1145  PatternRewriter &rewriter) const override {
1146  SmallVector<Value> foldedDynamicSizes;
1147  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1148  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1149 
1150  // Stop here if no dynamic size was promoted to static.
1151  if (foldedTensorType == op.getType())
1152  return failure();
1153 
1154  auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1155  foldedDynamicSizes);
1156  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1157  return success();
1158  }
1159 };
1160 
1161 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1162  using OpRewritePattern<DimOp>::OpRewritePattern;
1163 
1164  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1165  PatternRewriter &rewriter) const override {
1166  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1167  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1168  if (!emptyTensorOp || !maybeConstantIndex)
1169  return failure();
1170  auto emptyTensorType = emptyTensorOp.getType();
1171  if (*maybeConstantIndex < 0 ||
1172  *maybeConstantIndex >= emptyTensorType.getRank() ||
1173  !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1174  return failure();
1175  rewriter.replaceOp(dimOp,
1176  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1177  return success();
1178  }
1179 };
1180 
1181 /// Canonicalize
1182 ///
1183 /// ```mlir
1184 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1185 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1186 /// ```
1187 ///
1188 /// into
1189 ///
1190 /// ```mlir
1191 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1192 /// ```
1193 ///
1194 /// This assumes the input program is correct in terms of its shape. So it is
1195 /// safe to assume that `%d0` is in fact 4.
1196 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1197  using OpRewritePattern<CastOp>::OpRewritePattern;
1198 
1199  LogicalResult matchAndRewrite(CastOp castOp,
1200  PatternRewriter &rewriter) const override {
1201  if (!canFoldIntoProducerOp(castOp))
1202  return failure();
1203  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1204  if (!producer)
1205  return failure();
1206 
1207  auto resultType =
1208  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1209  ArrayRef<int64_t> resultShape = resultType.getShape();
1210  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1211  SmallVector<OpFoldResult> newMixedSizes;
1212  newMixedSizes.reserve(currMixedSizes.size());
1213  assert(resultShape.size() == currMixedSizes.size() &&
1214  "mismatch in result shape and sizes of empty op");
1215  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1216  int64_t newDim = std::get<0>(it);
1217  OpFoldResult currDim = std::get<1>(it);
1218  // Case 1: The empty tensor dim is static. Check that the tensor cast
1219  // result dim matches.
1220  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1221  if (ShapedType::isDynamic(newDim) ||
1222  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1223  // Something is off, the cast result shape cannot be more dynamic
1224  // than the empty tensor result shape (enforced by
1225  // `canFoldIntoProducer`). Abort for now.
1226  return rewriter.notifyMatchFailure(
1227  producer, "mismatch in static value of shape of empty tensor "
1228  "result and cast result");
1229  }
1230  newMixedSizes.push_back(attr);
1231  continue;
1232  }
1233 
1234  // Case 2 : The tensor cast shape is static, but empty tensor result
1235  // shape is dynamic.
1236  if (ShapedType::isStatic(newDim)) {
1237  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1238  continue;
1239  }
1240 
1241  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1242  // shape is dynamic. Use the dynamic value from the empty tensor op.
1243  newMixedSizes.push_back(currDim);
1244  }
1245 
1246  // TODO: Do not drop tensor encoding.
1247  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1248  resultType.getElementType());
1249  return success();
1250  }
1251 };
1252 
1253 } // namespace
1254 
1255 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1256  MLIRContext *context) {
1257  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1258  ReplaceEmptyTensorStaticShapeDims>(context);
1259 }
1260 
1261 //===----------------------------------------------------------------------===//
1262 // ExtractOp
1263 //===----------------------------------------------------------------------===//
1264 
1265 namespace {
1266 
1267 /// Canonicalizes the pattern of the form
1268 ///
1269 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1270 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1271 ///
1272 /// to
1273 ///
1274 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1275 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1276  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1277 
1278  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1279  PatternRewriter &rewriter) const final {
1280  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1281  if (!tensorCast)
1282  return failure();
1283  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1284  return failure();
1285  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1286  extract, tensorCast.getSource(), extract.getIndices());
1287  return success();
1288  }
1289 };
1290 
1291 /// Canonicalizes the pattern of the form
1292 ///
1293 /// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1294 /// tensor<12xf64>
1295 /// %extracted_element = tensor.extract %val[%c10] :
1296 /// tensor<12xf64>
1297 ///
1298 /// to
1299 ///
1300 /// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1301 struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1302  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1303 
1304  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1305  PatternRewriter &rewriter) const final {
1306  auto collapseOp =
1307  extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1308  if (!collapseOp)
1309  return failure();
1310  if (!collapseOp.getSrcType().hasStaticShape())
1311  return failure();
1312 
1313  auto sourceSizes = collapseOp.getSrcType().getShape();
1314 
1315  SmallVector<Value> indices(extractOp.getIndices().begin(),
1316  extractOp.getIndices().end());
1317  SmallVector<Value> sourceIndices;
1318  for (auto [index, group] :
1319  llvm::zip(indices, collapseOp.getReassociationIndices())) {
1320  assert(!group.empty() && "association indices groups cannot be empty");
1321  auto groupSize = group.size();
1322 
1323  if (groupSize == 1) {
1324  sourceIndices.push_back(index);
1325  continue;
1326  }
1327 
1328  SmallVector<int64_t> basis =
1329  llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1330  auto delinearize = affine::AffineDelinearizeIndexOp::create(
1331  rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1332  llvm::append_range(sourceIndices, delinearize.getResults());
1333  }
1334  if (collapseOp.getReassociationIndices().empty()) {
1335  auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1336  int64_t srcRank =
1337  cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1338  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1339  rewriter, extractOp.getLoc(), zeroAffineMap,
1340  ArrayRef<OpFoldResult>{});
1341  for (int64_t i = 0; i < srcRank; i++) {
1342  sourceIndices.push_back(
1343  getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1344  }
1345  }
1346 
1347  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1348  extractOp, collapseOp.getSrc(), sourceIndices);
1349  return success();
1350  }
1351 };
1352 
1353 } // namespace
1354 
1355 void ExtractOp::getAsmResultNames(
1356  function_ref<void(Value, StringRef)> setNameFn) {
1357  setNameFn(getResult(), "extracted");
1358 }
1359 
1360 LogicalResult ExtractOp::verify() {
1361  // Verify the # indices match if we have a ranked type.
1362  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1363  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1364  return emitOpError("incorrect number of indices for extract_element");
1365  return success();
1366 }
1367 
1368 /// If we have an ExtractOp consuming an InsertOp with the same
1369 /// indices, we can return the InsertOp's scalar directly.
1370 // TODO: This only checks the immediate producer; extend to go up the
1371 // insert/extract chain if the slices are disjoint.
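/// Illustrative sketch (not from the original source), with hypothetical IR:
///
/// ```mlir
/// %t = tensor.insert %v into %dest[%i, %j] : tensor<4x4xf32>
/// %e = tensor.extract %t[%i, %j] : tensor<4x4xf32>
/// ```
///
/// folds `%e` to `%v`, because the extract reads back exactly the inserted
/// scalar at the same indices.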
1372 static Value foldExtractAfterInsert(ExtractOp extractOp) {
1373  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1374 
1375  auto isSame = [](Value a, Value b) {
1376  return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1377  };
1378  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1379  llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1380  return insertOp.getScalar();
1381 
1382  return {};
1383 }
1384 
1385 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1386  if (Attribute tensor = adaptor.getTensor()) {
1387  // If this is a splat elements attribute, simply return the value.
1388  // All of the elements of a splat attribute are the same.
1389  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1390  return splatTensor.getSplatValue<Attribute>();
1391 
1392  // If this is a dense resource elements attribute, return.
1393  if (isa<DenseResourceElementsAttr>(tensor))
1394  return {};
1395  }
1396 
1397  // Collect the constant indices into the tensor.
1398  SmallVector<uint64_t, 8> indices;
1399  for (Attribute indice : adaptor.getIndices()) {
1400  if (!indice || !llvm::isa<IntegerAttr>(indice))
1401  return {};
1402  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1403  }
1404 
1405  // Fold extract(from_elements(...)).
1406  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1407  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1408  auto rank = tensorType.getRank();
1409  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1410  "rank mismatch");
1411  int flatIndex = 0;
1412  int stride = 1;
1413  for (int i = rank - 1; i >= 0; --i) {
1414  flatIndex += indices[i] * stride;
1415  stride *= tensorType.getDimSize(i);
1416  }
1417  // Prevent out of bounds accesses. This can happen in invalid code that
1418  // will never execute.
1419  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1420  flatIndex < 0)
1421  return {};
1422  return fromElementsOp.getElements()[flatIndex];
1423  }
1424 
1425  // If this is an elements attribute, query the value at the given indices.
1426  if (Attribute tensor = adaptor.getTensor()) {
1427  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1428  if (elementsAttr && elementsAttr.isValidIndex(indices))
1429  return elementsAttr.getValues<Attribute>()[indices];
1430  }
1431 
1432  if (Value result = foldExtractAfterInsert(*this))
1433  return result;
1434 
1435  return {};
1436 }
1437 
1438 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1439  MLIRContext *context) {
1440  results.add<ExtractFromTensorCast>(context);
1441 }
1442 
1443 void mlir::tensor::populateFoldCollapseExtractPatterns(
1444  RewritePatternSet &patterns) {
1445  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // FromElementsOp
1450 //===----------------------------------------------------------------------===//
1451 
1452 void FromElementsOp::getAsmResultNames(
1453  function_ref<void(Value, StringRef)> setNameFn) {
1454  setNameFn(getResult(), "from_elements");
1455 }
1456 
1457 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1458  ValueRange elements) {
1459  assert(!elements.empty() && "expected at least one element");
1460  Type resultType = RankedTensorType::get(
1461  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1462  build(builder, result, resultType, elements);
1463 }
1464 
1465 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1466  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1467  return DenseElementsAttr::get(getType(), adaptor.getElements());
1468  return {};
1469 }
1470 
1471 namespace {
1472 
1473 // Pushes the index_casts that occur before extractions to after the extract.
1474 // This minimizes type conversion in some cases and enables the extract
1475 // canonicalizer. This changes:
1476 //
1477 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1478 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1479 //
1480 // to the following:
1481 //
1482 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1483 // %cast = arith.index_cast %extract : i32 to index
1484 //
1486 //
1487 // Consider expanding this to a template and handle all tensor cast
1488 // operations.
1489 struct ExtractElementFromIndexCast
1490  : public OpRewritePattern<tensor::ExtractOp> {
1491  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1492 
1493  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1494  PatternRewriter &rewriter) const final {
1495  Location loc = extract.getLoc();
1496  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1497  if (!indexCast)
1498  return failure();
1499 
1500  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1501 
1502  auto newExtract = tensor::ExtractOp::create(
1503  rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1504 
1505  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1506  newExtract);
1507 
1508  return success();
1509  }
1510 };
1511 
1512 } // namespace
1513 
1514 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1515  MLIRContext *context) {
1516  results.add<ExtractElementFromIndexCast>(context);
1517 }
1518 
1519 //===----------------------------------------------------------------------===//
1520 // GatherOp
1521 //===----------------------------------------------------------------------===//
1522 
1523 void GatherOp::getAsmResultNames(
1524  function_ref<void(Value, StringRef)> setNameFn) {
1525  setNameFn(getResult(), "gather");
1526 }
1527 
1528 /// Return the inferred result type for a gatherOp where:
1529 /// - sourceType is the type of the source tensor gathered from
1530 /// - indicesType is the type of the indices used to gather
1531 /// - gatherDims are the dims along which the gather occurs.
1532 /// Return a full rank or ranked-reduced variant of the type depending on
1533 /// the value of rankReduced.
1534 ///
1535 /// The leading dimensions of the index tensor give the result tensor its
1536 /// leading dimensions.
1537 /// The trailing dimensions of the result tensor are obtained from the source
1538 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1539 /// rankedReduced is false), or skipping them (otherwise).
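/// Illustrative example (not from the original source), with hypothetical
/// types: for a source tensor<4x5x6xf32>, indices tensor<1x2x1xindex> and
/// gather_dims = [1], the inferred type is tensor<1x2x4x1x6xf32> when
/// rankReduced is false, and tensor<1x2x4x6xf32> when it is true.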
1540 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1541  RankedTensorType indicesType,
1542  ArrayRef<int64_t> gatherDims,
1543  bool rankReduced) {
1544  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1545  resultShape.reserve(resultShape.size() + sourceType.getRank());
1546  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1547  if (llvm::binary_search(gatherDims, idx)) {
1548  if (!rankReduced)
1549  resultShape.push_back(1);
1550  continue;
1551  }
1552  resultShape.push_back(sourceType.getDimSize(idx));
1553  }
1554  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1555 }
1556 
1557 static LogicalResult
1558 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1559  ArrayRef<int64_t> indices, int64_t rank,
1560  StringRef gatherOrScatter, StringRef sourceOrDest) {
1561  if (dims.empty())
1562  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1563 
1564  int64_t numGatherDims = dims.size();
1565  if (numGatherDims > rank)
1566  return op->emitOpError(gatherOrScatter)
1567  << "_dims overflow " << sourceOrDest << " rank";
1568  if (indices.empty() || indices.back() != numGatherDims)
1569  return op->emitOpError(gatherOrScatter)
1570  << "_dims length must match the size of last dimension of indices";
1571  for (int64_t val : dims) {
1572  if (val < 0)
1573  return op->emitOpError(gatherOrScatter)
1574  << "_dims value must be non-negative";
1575  if (val >= rank)
1576  return op->emitOpError(gatherOrScatter)
1577  << "_dims value must be smaller than " << sourceOrDest << " rank";
1578  }
1579  for (int64_t i = 1; i < numGatherDims; ++i) {
1580  if (dims[i - 1] >= dims[i])
1581  return op->emitOpError(gatherOrScatter)
1582  << "_dims values must be strictly increasing";
1583  }
1584  return success();
1585 }
1586 
1587 LogicalResult GatherOp::verify() {
1588  int64_t sourceRank = getSourceType().getRank();
1589  ArrayRef<int64_t> gatherDims = getGatherDims();
1590  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1591  getIndicesType().getShape(), sourceRank,
1592  "gather", "source")))
1593  return failure();
1594 
1595  RankedTensorType expectedResultType = GatherOp::inferResultType(
1596  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1597  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1598  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1599  if (getResultType() != expectedResultType &&
1600  getResultType() != expectedRankReducedResultType) {
1601  return emitOpError("result type "
1602  "mismatch: "
1603  "expected ")
1604  << expectedResultType << " or its rank-reduced variant "
1605  << expectedRankReducedResultType << " (got: " << getResultType()
1606  << ")";
1607  }
1608 
1609  return success();
1610 }
1611 
1612 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1613  if (OpFoldResult reshapedSource = reshapeConstantSource(
1614  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1615  getResult().getType()))
1616  return reshapedSource;
1617  return {};
1618 }
1619 
1620 //===----------------------------------------------------------------------===//
1621 // InsertOp
1622 //===----------------------------------------------------------------------===//
1623 
1624 void InsertOp::getAsmResultNames(
1625  function_ref<void(Value, StringRef)> setNameFn) {
1626  setNameFn(getResult(), "inserted");
1627 }
1628 
1629 LogicalResult InsertOp::verify() {
1630  // Verify the # indices match if we have a ranked type.
1631  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1632  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1633  return emitOpError("incorrect number of indices");
1634  return success();
1635 }
1636 
1637 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1638  Attribute scalar = adaptor.getScalar();
1639  Attribute dest = adaptor.getDest();
1640  if (scalar && dest)
1641  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1642  if (scalar == splatDest.getSplatValue<Attribute>())
1643  return dest;
1644  return {};
1645 }
1646 
1647 //===----------------------------------------------------------------------===//
1648 // GenerateOp
1649 //===----------------------------------------------------------------------===//
1650 
1651 void GenerateOp::getAsmResultNames(
1652  function_ref<void(Value, StringRef)> setNameFn) {
1653  setNameFn(getResult(), "generated");
1654 }
1655 
1656 LogicalResult GenerateOp::reifyResultShapes(
1657  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1658  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1659  int idx = 0;
1660  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1661  if (getType().isDynamicDim(dim)) {
1662  reifiedReturnShapes[0][dim] = getOperand(idx++);
1663  } else {
1664  reifiedReturnShapes[0][dim] =
1665  builder.getIndexAttr(getType().getDimSize(dim));
1666  }
1667  }
1668  return success();
1669 }
1670 
1671 LogicalResult GenerateOp::verify() {
1672  // Ensure that the tensor type has as many dynamic dimensions as are
1673  // specified by the operands.
1674  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1675  if (getNumOperands() != resultType.getNumDynamicDims())
1676  return emitError("must have as many index operands as dynamic extents "
1677  "in the result type");
1678  return success();
1679 }
1680 
1681 LogicalResult GenerateOp::verifyRegions() {
1682  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1683  // Ensure that region arguments span the index space.
1684  if (!llvm::all_of(getBody().getArgumentTypes(),
1685  [](Type ty) { return ty.isIndex(); }))
1686  return emitError("all body arguments must be index");
1687  if (getBody().getNumArguments() != resultTy.getRank())
1688  return emitError("must have one body argument per input dimension");
1689 
1690  // Ensure that the region yields an element of the right type.
1691  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1692 
1693  if (yieldOp.getValue().getType() != resultTy.getElementType())
1694  return emitOpError(
1695  "body must be terminated with a `yield` operation of the tensor "
1696  "element type");
1697 
1698  return success();
1699 }
1700 
1701 void GenerateOp::build(
1702  OpBuilder &b, OperationState &result, Type resultTy,
1703  ValueRange dynamicExtents,
1704  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1705  build(b, result, resultTy, dynamicExtents);
1706 
1707  // Build and populate body.
1708  OpBuilder::InsertionGuard guard(b);
1709  Region *bodyRegion = result.regions.front().get();
1710  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1711  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1712  SmallVector<Location, 2> argumentLocs(rank, result.location);
1713  Block *bodyBlock =
1714  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1715  bodyBuilder(b, result.location, bodyBlock->getArguments());
1716 }
1717 
1718 namespace {
1719 
1720 /// Canonicalizes tensor.generate operations whose dynamic extent operands
1721 /// are constants into the equivalent operation with those extents folded
1722 /// into the result type. We also insert a tensor.cast to make sure that the
1723 /// resulting IR is still well-typed.
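/// A sketch of the rewrite (values and types here are illustrative):
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = tensor.generate %c8 { ... } : tensor<?xindex>
/// ```
///
/// becomes
///
/// ```mlir
/// %0 = tensor.generate { ... } : tensor<8xindex>
/// %1 = tensor.cast %0 : tensor<8xindex> to tensor<?xindex>
/// ```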
1724 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1725  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1726 
1727  LogicalResult matchAndRewrite(GenerateOp generateOp,
1728  PatternRewriter &rewriter) const final {
1729  SmallVector<Value> foldedDynamicSizes;
1730  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1731  generateOp.getType(), generateOp.getDynamicExtents(),
1732  foldedDynamicSizes);
1733 
1734  // Stop here if no dynamic size was promoted to static.
1735  if (foldedTensorType == generateOp.getType())
1736  return failure();
1737 
1738  auto loc = generateOp.getLoc();
1739  auto newOp =
1740  GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1741  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1742  newOp.getBody().begin());
1743  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1744  generateOp.getType(), newOp);
1745  return success();
1746  }
1747 };
1748 
1749 /// Canonicalizes the pattern of the form
1750 ///
1751 /// %tensor = tensor.generate %x {
1752 /// ^bb0(%arg0: index):
1753 /// <computation>
1754 /// yield %1 : index
1755 /// } : tensor<?xindex>
1756 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1757 ///
1758 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1759 /// tensor.generate operation has no side-effects.
1760 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1761  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1762 
1763  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1764  PatternRewriter &rewriter) const final {
1765  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1766  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1767  return failure();
1768 
1769  IRMapping mapping;
1770  Block *body = &tensorFromElements.getBody().front();
1771  mapping.map(body->getArguments(), extract.getIndices());
1772  for (auto &op : body->without_terminator())
1773  rewriter.clone(op, mapping);
1774 
1775  auto yield = cast<YieldOp>(body->getTerminator());
1776 
1777  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1778  return success();
1779  }
1780 };
1781 
1782 } // namespace
1783 
1784 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1785  MLIRContext *context) {
1786  // TODO: Move extract pattern to tensor::ExtractOp.
1787  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1788 }
1789 
1790 //===----------------------------------------------------------------------===//
1791 // RankOp
1792 //===----------------------------------------------------------------------===//
1793 
1794 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1795  setNameFn(getResult(), "rank");
1796 }
1797 
1798 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1799  // Constant fold rank when the rank of the operand is known.
1800  auto type = getOperand().getType();
1801  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1802  if (shapedType && shapedType.hasRank())
1803  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1804  return IntegerAttr();
1805 }
1806 
1807 //===----------------------------------------------------------------------===//
1808 // ReshapeOp
1809 //===----------------------------------------------------------------------===//
1810 
1811 void ReshapeOp::getAsmResultNames(
1812  function_ref<void(Value, StringRef)> setNameFn) {
1813  setNameFn(getResult(), "reshape");
1814 }
1815 
1816 static int64_t getNumElements(ShapedType type) {
1817  int64_t numElements = 1;
1818  for (auto dim : type.getShape())
1819  numElements *= dim;
1820  return numElements;
1821 }
1822 
1823 LogicalResult ReshapeOp::verify() {
1824  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1825  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1826 
1827  if (operandType.getElementType() != resultType.getElementType())
1828  return emitOpError("element types of source and destination tensor "
1829  "types should be the same");
1830 
1831  int64_t shapeSize =
1832  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1833  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1834  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1835 
1836  if (resultRankedType) {
1837  if (operandRankedType && resultRankedType.hasStaticShape() &&
1838  operandRankedType.hasStaticShape()) {
1839  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1840  return emitOpError("source and destination tensor should have the "
1841  "same number of elements");
1842  }
1843  if (ShapedType::isDynamic(shapeSize))
1844  return emitOpError("cannot use shape operand with dynamic length to "
1845  "reshape to statically-ranked tensor type");
1846  if (shapeSize != resultRankedType.getRank())
1847  return emitOpError(
1848  "length of shape operand differs from the result's tensor rank");
1849  }
1850  return success();
1851 }
1852 
1853 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1854  if (OpFoldResult reshapedSource = reshapeConstantSource(
1855  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1856  getResult().getType()))
1857  return reshapedSource;
1858 
1859  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1860  // producer's input instead as the original tensor to reshape. This can
1861  // render the producer dead, to be cleaned up by DCE.
1862  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1863  getSourceMutable().assign(reshapeOpProducer.getSource());
1864  return getResult();
1865  }
1866 
1867  auto source = getSource();
1868  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1869  auto resultTy = dyn_cast<RankedTensorType>(getType());
1870  if (!sourceTy || !resultTy || sourceTy != resultTy)
1871  return {};
1872 
1873  // If the source and result are both 0D or 1D tensors and have the same type,
1874  // the reshape has no effect, even if the tensor is dynamically shaped.
1875  if (sourceTy.getRank() <= 1)
1876  return source;
1877 
1878  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1879  auto elements = fromElements.getElements();
1880  bool dynamicNoop =
1881  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1882  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1883  auto element = elements[id];
1884 
1885  if (auto cst = getConstantIntValue(element)) {
1886  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1887  continue;
1888  }
1889 
1890  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1891  dynamicNoop &= dimOp.getSource() == source;
1892 
1893  auto cst = getConstantIntValue(dimOp.getIndex());
1894  dynamicNoop &=
1895  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1896  continue;
1897  }
1898 
1899  dynamicNoop = false;
1900  break;
1901  }
1902 
1903  if (dynamicNoop)
1904  return source;
1905  }
1906 
1907  return {};
1908 }
1909 
1910 //===----------------------------------------------------------------------===//
1911 // Reassociative reshape ops
1912 //===----------------------------------------------------------------------===//
1913 
1914 void CollapseShapeOp::getAsmResultNames(
1915  function_ref<void(Value, StringRef)> setNameFn) {
1916  setNameFn(getResult(), "collapsed");
1917 }
1918 
1919 void ExpandShapeOp::getAsmResultNames(
1920  function_ref<void(Value, StringRef)> setNameFn) {
1921  setNameFn(getResult(), "expanded");
1922 }
1923 
1924 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1925  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1926  "invalid resultDim");
1927  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1928  if (llvm::is_contained(it.value(), resultDim))
1929  return it.index();
1930  llvm_unreachable("could not find reassociation group");
1931 }
1932 
1933 FailureOr<SmallVector<OpFoldResult>>
1934 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1935  RankedTensorType expandedType,
1936  ArrayRef<ReassociationIndices> reassociation,
1937  ArrayRef<OpFoldResult> inputShape) {
1938  std::optional<SmallVector<OpFoldResult>> outputShape =
1939  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1940  inputShape);
1941  if (!outputShape)
1942  return failure();
1943  return *outputShape;
1944 }
1945 
1946 SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1947  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1948 }
1949 
1950 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1951  Type resultType, Value src,
1952  ArrayRef<ReassociationIndices> reassociation,
1953  ArrayRef<OpFoldResult> outputShape) {
1954  auto [staticOutputShape, dynamicOutputShape] =
1955  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1956  build(builder, result, cast<RankedTensorType>(resultType), src,
1957  getReassociationIndicesAttribute(builder, reassociation),
1958  dynamicOutputShape, staticOutputShape);
1959 }
1960 
1961 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1962  Type resultType, Value src,
1963  ArrayRef<ReassociationIndices> reassociation) {
1964  SmallVector<OpFoldResult> inputShape =
1965  getMixedSizes(builder, result.location, src);
1966  auto tensorResultTy = cast<RankedTensorType>(resultType);
1967  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1968  builder, result.location, tensorResultTy, reassociation, inputShape);
1969  SmallVector<OpFoldResult> outputShapeOrEmpty;
1970  if (succeeded(outputShape)) {
1971  outputShapeOrEmpty = *outputShape;
1972  }
1973  build(builder, result, tensorResultTy, src, reassociation,
1974  outputShapeOrEmpty);
1975 }
1976 
1977 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1978  return getSymbolLessAffineMaps(getReassociationExprs());
1979 }
1980 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1981  return convertReassociationIndicesToExprs(getContext(),
1982  getReassociationIndices());
1983 }
1984 
1985 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1986  return getSymbolLessAffineMaps(getReassociationExprs());
1987 }
1988 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1989  return convertReassociationIndicesToExprs(getContext(),
1990  getReassociationIndices());
1991 }
1992 
1993 RankedTensorType CollapseShapeOp::inferCollapsedType(
1994  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1995  return inferCollapsedType(
1996  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1997  type.getContext(), reassociation)));
1998 }
1999 
2000 /// Compute the RankedTensorType obtained by applying `reassociation` to
2001 /// `type`.
2002 RankedTensorType
2003 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2004  ArrayRef<AffineMap> reassociation) {
2005  auto shape = type.getShape();
2006  SmallVector<int64_t, 4> newShape;
2007  newShape.reserve(reassociation.size());
2008 
2009  // Use the fact that reassociation is valid to simplify the logic: only use
2010  // each map's rank.
2011  assert(isReassociationValid(reassociation) && "invalid reassociation");
2012  unsigned currentDim = 0;
2013  for (AffineMap m : reassociation) {
2014  unsigned dim = m.getNumResults();
2015  auto band = shape.slice(currentDim, dim);
2016  int64_t size = 1;
2017  if (llvm::is_contained(band, ShapedType::kDynamic))
2018  size = ShapedType::kDynamic;
2019  else
2020  for (unsigned d = 0; d < dim; ++d)
2021  size *= shape[currentDim + d];
2022  newShape.push_back(size);
2023  currentDim += dim;
2024  }
2025 
2026  return RankedTensorType::get(newShape, type.getElementType());
2027 }
2028 
2029 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2030  ArrayRef<ReassociationIndices> reassociation,
2031  ArrayRef<NamedAttribute> attrs) {
2032  auto resultType = inferCollapsedType(
2033  llvm::cast<RankedTensorType>(src.getType()),
2034  getSymbolLessAffineMaps(
2035  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2036  result.addAttribute(getReassociationAttrStrName(),
2037  getReassociationIndicesAttribute(b, reassociation));
2038  build(b, result, resultType, src, attrs);
2039 }
2040 
2041 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2042  TensorReshapeOp, ExpandShapeOp>::value>
2043 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2044  RankedTensorType expandedType,
2045  RankedTensorType collapsedType) {
2046  if (failed(
2047  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2048  return failure();
2049 
2050  auto maps = op.getReassociationMaps();
2051  RankedTensorType expectedType =
2052  CollapseShapeOp::inferCollapsedType(expandedType, maps);
2053  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2054  return op.emitOpError("expected collapsed type to be ")
2055  << expectedType << ", but got " << collapsedType;
2056  return success();
2057 }
2058 
2059 LogicalResult ExpandShapeOp::verify() {
2060  auto srcType = getSrcType();
2061  auto resultType = getResultType();
2062 
2063  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2064  return emitOpError("expected number of static shape dims to be equal to "
2065  "the output rank (")
2066  << resultType.getRank() << ") but found "
2067  << getStaticOutputShape().size() << " inputs instead";
2068 
2069  if ((int64_t)getOutputShape().size() !=
2070  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2071  return emitOpError("mismatch in dynamic dims in output_shape and "
2072  "static_output_shape: static_output_shape has ")
2073  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2074  << " dynamic dims while output_shape has " << getOutputShape().size()
2075  << " values";
2076 
2077  return verifyTensorReshapeOp(*this, resultType, srcType);
2078 }
2079 
2080 LogicalResult CollapseShapeOp::verify() {
2081  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2082 }
2083 
2084 namespace {
2085 /// Reshape of a splat constant can be replaced with a constant of the result
2086 /// type.
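/// A sketch of the fold (shapes and values are illustrative):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<4xf32>
/// %0 = tensor.expand_shape %cst [[0, 1]] output_shape [2, 2]
///     : tensor<4xf32> into tensor<2x2xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %0 = arith.constant dense<1.0> : tensor<2x2xf32>
/// ```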
2087 template <typename TensorReshapeOp>
2088 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2089  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2090  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2091  PatternRewriter &rewriter) const override {
2092  DenseElementsAttr attr;
2093  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2094  return failure();
2095  if (!attr || !attr.isSplat())
2096  return failure();
2097  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2098  reshapeOp.getResultType(), attr.getRawData());
2099  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2100  return success();
2101  }
2102 };
2103 
2104 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
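// For example (illustrative): collapsing a `tensor.splat %x : tensor<2x3xf32>`
// into tensor<6xf32> becomes `tensor.splat %x : tensor<6xf32>`, provided the
// splat result type is statically shaped.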
2105 template <typename TensorReshapeOp>
2106 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2107 public:
2108  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2109 
2110  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2111  PatternRewriter &rewriter) const override {
2112  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2113  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2114  return failure();
2115 
2116  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2117  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2118  return success();
2119  }
2120 };
2121 
2122 /// Reshape of a FromElements can be replaced with a FromElements of the
2123 /// result type
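/// A sketch of the fold (shapes are illustrative):
///
/// ```mlir
/// %0 = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 2]
///     : tensor<4xf32> into tensor<2x2xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %1 = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
/// ```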
2124 template <typename TensorReshapeOp>
2125 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2126  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2127  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2128  PatternRewriter &rewriter) const override {
2129  auto fromElements =
2130  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2131  if (!fromElements)
2132  return failure();
2133 
2134  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2135 
2136  if (!shapedTy.hasStaticShape())
2137  return failure();
2138 
2139  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2140  fromElements.getElements());
2141  return success();
2142  }
2143 };
2144 
2145 // Fold CastOp into CollapseShapeOp when adding static information.
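// For example (illustrative): a collapse_shape whose source is a cast from
// tensor<2x3x4xf32> to tensor<?x3x4xf32> is rewritten to collapse the original
// tensor<2x3x4xf32> directly; if that makes the collapsed type more static, a
// tensor.cast back to the old result type is inserted.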
2146 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2147  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2148 
2149  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2150  PatternRewriter &rewriter) const override {
2151  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2152  if (!tensor::canFoldIntoConsumerOp(castOp))
2153  return failure();
2154 
2155  RankedTensorType srcType =
2156  llvm::cast<RankedTensorType>(castOp.getSource().getType());
2157  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2158  srcType, collapseShapeOp.getReassociationMaps());
2159 
2160  if (newResultType == collapseShapeOp.getResultType()) {
2161  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2162  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2163  });
2164  } else {
2165  auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2166  newResultType, castOp.getSource(),
2167  collapseShapeOp.getReassociation());
2168  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2169  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2170  }
2171  return success();
2172  }
2173 };
2174 
2175 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2176 /// matching constant output_shape operands of the expand. This makes the
2177 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2178 /// propagated further.
2179 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2180  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2181 
2182  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2183  PatternRewriter &rewriter) const override {
2184  auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2185  if (!canFoldIntoConsumerOp(castOp))
2186  return failure();
2187 
2188  ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2189  SmallVector<ReassociationIndices, 4> reassoc =
2190  expandOp.getReassociationIndices();
2191 
2192  SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2193  SmallVector<Value> dynamicOutputShape;
2194  auto outputIt = expandOp.getOutputShape().begin();
2195 
2196  for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2197  for (uint64_t outDim : innerReassoc) {
2198  if (ShapedType::isStatic(newOutputShape[outDim]))
2199  continue;
2200 
2201  // If the cast's src type is dynamic, don't infer any of the
2202  // corresponding expanded dimensions. `tensor.expand_shape` requires at
2203  // least one of the expanded dimensions to be dynamic if the input is
2204  // dynamic.
2205  Value val = *outputIt;
2206  ++outputIt;
2207  if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2208  dynamicOutputShape.push_back(val);
2209  continue;
2210  }
2211 
2212  APInt cst;
2213  if (matchPattern(val, m_ConstantInt(&cst))) {
2214  newOutputShape[outDim] = cst.getSExtValue();
2215  } else {
2216  dynamicOutputShape.push_back(val);
2217  }
2218  }
2219  }
2220 
2221  // Couldn't match any values, nothing to change
2222  if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2223  return failure();
2224 
2225  // Calculate the input shape from the output
2226  SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2227  for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2228  for (auto outDim : reassoc[inDim]) {
2229  auto ofr = newOutputShape[outDim];
2230  if (ShapedType::isDynamic(ofr)) {
2231  newInputShape[inDim] = ShapedType::kDynamic;
2232  break;
2233  }
2234  newInputShape[inDim] *= ofr;
2235  }
2236  }
2237 
2238  SmallVector<OpFoldResult> outputOfr =
2239  getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2240  auto inputType = RankedTensorType::get(
2241  newInputShape, expandOp.getSrcType().getElementType());
2242  auto outputType = RankedTensorType::get(
2243  newOutputShape, expandOp.getSrcType().getElementType());
2244  auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2245  expandOp.getSrc());
2246  auto newExpand = ExpandShapeOp::create(
2247  rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2248  expandOp.getReassociationIndices(), outputOfr);
2249  rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2250  newExpand.getResult());
2251  return success();
2252  }
2253 };
2254 } // namespace
2255 
2256 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2257  MLIRContext *context) {
2258  results.add<
2259  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2260  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2261  ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2262  FoldReshapeWithSplat<ExpandShapeOp>,
2263  FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2264 }
2265 
2266 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2267  MLIRContext *context) {
2268  results.add<
2269  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2270  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2271  tensor::DimOp, RankedTensorType>,
2272  FoldReshapeWithConstant<CollapseShapeOp>,
2273  FoldReshapeWithSplat<CollapseShapeOp>,
2274  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2275  context);
2276 }
2277 
2278 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2279  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2280  adaptor.getOperands());
2281 }
2282 
2283 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2284  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2285  adaptor.getOperands());
2286 }
2287 
2288 //===----------------------------------------------------------------------===//
2289 // ExtractSliceOp
2290 //===----------------------------------------------------------------------===//
2291 
2292 void ExtractSliceOp::getAsmResultNames(
2293  function_ref<void(Value, StringRef)> setNameFn) {
2294  setNameFn(getResult(), "extracted_slice");
2295 }
2296 
2297 /// An extract_slice result type can be inferred, when it is not
2298 /// rank-reduced, from the source type and the static representation of
2299 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
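/// For example (a hypothetical case), the static sizes alone determine the
/// inferred type:
///
/// ```mlir
/// %0 = tensor.extract_slice %t[0, %off][3, 4][1, 1]
///     : tensor<8x16xf32> to tensor<3x4xf32>
/// ```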
2300 RankedTensorType ExtractSliceOp::inferResultType(
2301  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2302  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2303  // An extract_slice op may specify only a leading subset of offsets/sizes/
2304  // strides, in which case we complete with offset=0, sizes from the source
2305  // tensor type, and strides=1.
2306  assert(static_cast<int64_t>(staticSizes.size()) ==
2307  sourceTensorType.getRank() &&
2308  "unexpected staticSizes not equal to rank of source");
2309  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2310  sourceTensorType.getEncoding());
2311 }
2312 
2313 RankedTensorType ExtractSliceOp::inferResultType(
2314  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2315  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2316  SmallVector<int64_t> staticSizes;
2317  std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2318  assert(static_cast<int64_t>(staticSizes.size()) ==
2319  sourceTensorType.getRank() &&
2320  "unexpected staticSizes not equal to rank of source");
2321  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2322  sourceTensorType.getEncoding());
2323 }
2324 
2325 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2326 /// number of sizes), drop as many size-1 dimensions as needed to produce an inferred
2327 /// type with the desired rank.
2328 ///
2329 /// Note that there may be multiple ways to compute this rank-reduced type:
2330 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2331 ///
2332 /// To disambiguate, this function always drops the earliest size-1 dimensions.
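/// For example, with sizes 1x6x1 and desiredResultRank = 2, the leading unit
/// dimension is dropped and the inferred shape is 6x1 (not 1x6).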
2333 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2334  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2335  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2336  ArrayRef<int64_t> strides) {
2337  // Type inferred in the absence of rank-reducing behavior.
2338  auto inferredType = llvm::cast<RankedTensorType>(
2339  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2340  int rankDiff = inferredType.getRank() - desiredResultRank;
2341  if (rankDiff > 0) {
2342  auto shape = inferredType.getShape();
2343  llvm::SmallBitVector dimsToProject =
2344  getPositionsOfShapeOne(rankDiff, shape);
2345  SmallVector<int64_t> projectedShape;
2346  // Best effort rank-reducing: drop 1s in order.
2347  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2348  if (!dimsToProject.test(pos))
2349  projectedShape.push_back(shape[pos]);
2350  inferredType =
2351  RankedTensorType::get(projectedShape, inferredType.getElementType());
2352  }
2353  return inferredType;
2354 }
2355 
2356 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2357  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2358  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2359  ArrayRef<OpFoldResult> strides) {
2360  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2361  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2362  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2363  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2364  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2365  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2366  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2367  staticStrides);
2368 }
2369 
2370 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2371 /// result type. If the type passed is nullptr, it is inferred.
2372 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2373  RankedTensorType resultType, Value source,
2374  ArrayRef<OpFoldResult> offsets,
2375  ArrayRef<OpFoldResult> sizes,
2376  ArrayRef<OpFoldResult> strides,
2377  ArrayRef<NamedAttribute> attrs) {
2378  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2379  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2380  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2381  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2382  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2383  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2384  // Structuring implementation this way avoids duplication between builders.
2385  if (!resultType) {
2386  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2387  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2388  }
2389  result.addAttributes(attrs);
2390  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2391  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2392  b.getDenseI64ArrayAttr(staticSizes),
2393  b.getDenseI64ArrayAttr(staticStrides));
2394 }
2395 
2396 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2397 /// result type.
2398 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2399  ArrayRef<OpFoldResult> offsets,
2400  ArrayRef<OpFoldResult> sizes,
2401  ArrayRef<OpFoldResult> strides,
2402  ArrayRef<NamedAttribute> attrs) {
2403  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2404 }
2405 
2406 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2407 /// a Range vector.
2408 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2409  ArrayRef<Range> ranges,
2410  ArrayRef<NamedAttribute> attrs) {
2411  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2412  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2413 }
2414 
2415 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2416 /// the type passed is nullptr, it is inferred.
2417 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2418  RankedTensorType resultType, Value source,
2419  ValueRange offsets, ValueRange sizes,
2420  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2421  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2422  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2423  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2424  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2425  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2426  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2427  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2428 }
2429 
2430 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2431 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2432  ValueRange offsets, ValueRange sizes,
2433  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2434  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2435 }
2436 
2437 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2438  Operation *op,
2439  RankedTensorType expectedType) {
2440  switch (result) {
2441  case SliceVerificationResult::Success:
2442  return success();
2443  case SliceVerificationResult::RankTooLarge:
2444  return op->emitError("expected rank to be smaller or equal to ")
2445  << "the other rank. ";
2446  case SliceVerificationResult::SizeMismatch:
2447  return op->emitError("expected type to be ")
2448  << expectedType << " or a rank-reduced version. (size mismatch) ";
2449  case SliceVerificationResult::ElemTypeMismatch:
2450  return op->emitError("expected element type to be ")
2451  << expectedType.getElementType();
2452  default:
2453  llvm_unreachable("unexpected extract_slice op verification result");
2454  }
2455 }
2456 
2457 /// Verifier for ExtractSliceOp.
2458 LogicalResult ExtractSliceOp::verify() {
2459  RankedTensorType sourceType = getSourceType();
2460 
2461  // Verify result type against inferred type.
2462  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2463  sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2464  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2465  if (result != SliceVerificationResult::Success)
2466  return produceSliceErrorMsg(result, *this, expectedType);
2467 
2468  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2469  // to the source tensor.
2470  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2471  sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2472  getStaticStrides(), /*generateErrorMessage=*/true);
2473  if (!boundsResult.isValid)
2474  return getOperation()->emitError(boundsResult.errorMessage);
2475 
2476  return success();
2477 }
2478 
2479 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2480  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2481 }
2482 
2483 FailureOr<Value>
2484 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2485  ArrayRef<int64_t> desiredShape) {
2486  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2487  assert(sourceTensorType && "not a ranked tensor type");
2488  auto sourceShape = sourceTensorType.getShape();
2489  if (sourceShape.equals(desiredShape))
2490  return value;
2491  auto maybeRankReductionMask =
2492  mlir::computeRankReductionMask(sourceShape, desiredShape);
2493  if (!maybeRankReductionMask)
2494  return failure();
2495  return createCanonicalRankReducingExtractSliceOp(
2496  b, loc, value,
2497  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2498 }
2499 
2500 LogicalResult ExtractSliceOp::reifyResultShapes(
2501  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2502  reifiedReturnShapes.resize(1);
2503  reifiedReturnShapes[0].reserve(getType().getRank());
2504  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2505  llvm::SmallBitVector droppedDims = getDroppedDims();
2506  for (const auto &size : enumerate(mixedSizes)) {
2507  if (droppedDims.test(size.index()))
2508  continue;
2509  reifiedReturnShapes[0].push_back(size.value());
2510  }
2511  return success();
2512 }
2513 
2514 namespace {
2515 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2516 /// This essentially pushes the tensor.cast past its consuming slice when
2517 /// `canFoldIntoConsumerOp` is true.
2518 ///
2519 /// Example:
2520 /// ```
2521 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2522 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2523 /// tensor<3x4xf32>
2524 /// ```
2525 /// is rewritten into:
2526 /// ```
2527 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2528 /// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2529 /// ```
2530 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2531 public:
2532  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2533 
2534  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2535  PatternRewriter &rewriter) const override {
2536  // Any constant operand, just return to let the constant folder kick in.
2537  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2538  return matchPattern(operand, matchConstantIndex());
2539  }))
2540  return failure();
2541 
2542  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2543  if (!castOp)
2544  return failure();
2545 
2546  if (!canFoldIntoConsumerOp(castOp))
2547  return failure();
2548 
2549  // Pattern does not apply if the produced op would not verify.
2550  SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2551  cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2552  sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2553  sliceOp.getStaticStrides());
2554  if (!sliceResult.isValid)
2555  return failure();
2556 
2557  // Create folded extract.
2558  Location loc = sliceOp.getLoc();
2559  Value newResult = ExtractSliceOp::create(
2560  rewriter, loc, sliceOp.getType(), castOp.getSource(),
2561  sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2562  sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2563  sliceOp.getStaticStrides());
2564  rewriter.replaceOp(sliceOp, newResult);
2565  return success();
2566  }
2567 };
2568 
2569 /// Slice elements from `values` into `outValues`. For each dimension, `counts`
2570 /// holds the number of elements spanned by one step (e.g. [12, 4, 1] for 2x3x4).
2571 /// The output values can be used to construct a DenseElementsAttr.
2572 template <typename IterTy, typename ElemTy>
2573 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2574  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2575  ArrayRef<int64_t> strides,
2576  llvm::SmallVectorImpl<ElemTy> *outValues) {
2577  assert(offsets.size() == sizes.size());
2578  assert(offsets.size() == strides.size());
2579  if (offsets.empty())
2580  return;
2581 
2582  int64_t offset = offsets.front();
2583  int64_t size = sizes.front();
2584  int64_t stride = strides.front();
2585  if (offsets.size() == 1) {
2586  for (int64_t i = 0; i < size; ++i, offset += stride)
2587  outValues->push_back(*(values + offset));
2588 
2589  return;
2590  }
2591 
2592  for (int64_t i = 0; i < size; ++i, offset += stride) {
2593  auto begin = values + offset * counts.front();
2594  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2595  offsets.drop_front(), sizes.drop_front(),
2596  strides.drop_front(), outValues);
2597  }
2598 }
2599 
2600 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2601 /// folded operation might introduce more constant data; users can control
2602 /// the folding heuristics via the control function.
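/// A sketch of the fold (constant values are illustrative):
///
/// ```mlir
/// %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
/// %0 = tensor.extract_slice %cst[0, 1][2, 1][1, 1]
///     : tensor<2x2xi32> to tensor<2x1xi32>
/// ```
///
/// becomes
///
/// ```mlir
/// %0 = arith.constant dense<[[1], [3]]> : tensor<2x1xi32>
/// ```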
2603 class ConstantOpExtractSliceFolder final
2604  : public OpRewritePattern<ExtractSliceOp> {
2605 public:
2606  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2607 
2608  ConstantOpExtractSliceFolder(MLIRContext *context,
2609  ControlConstantExtractSliceFusionFn controlFn)
2610  : OpRewritePattern<ExtractSliceOp>(context),
2611  controlFn(std::move(controlFn)) {}
2612 
2613  LogicalResult matchAndRewrite(ExtractSliceOp op,
2614  PatternRewriter &rewriter) const override {
2615  DenseElementsAttr attr;
2616  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2617  return failure();
2618 
2619  // A constant splat is handled by fold().
2620  if (attr.isSplat())
2621  return failure();
2622 
2623  // Dynamic result shape is not supported.
2624  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2625  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2626  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2627  return failure();
2628 
2629  // Customized control over the folding.
2630  if (!controlFn(op))
2631  return failure();
2632 
2633  int64_t count = sourceType.getNumElements();
2634  if (count == 0)
2635  return failure();
2636 
2637  // Check if there are any dynamic parts, which are not supported.
2638  auto offsets = op.getStaticOffsets();
2639  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2640  return failure();
2641  auto sizes = op.getStaticSizes();
2642  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2643  return failure();
2644  auto strides = op.getStaticStrides();
2645  if (llvm::is_contained(strides, ShapedType::kDynamic))
2646  return failure();
2647 
2648  // Compute the stride for each dimension.
2649  SmallVector<int64_t> counts;
2650  ArrayRef<int64_t> shape = sourceType.getShape();
2651  counts.reserve(shape.size());
2652  for (int64_t v : shape) {
2653  count = count / v;
2654  counts.push_back(count);
2655  }
2656 
2657  // New attribute constructed by the sliced values.
2658  DenseElementsAttr newAttr;
2659 
2660  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2661  SmallVector<APInt> outValues;
2662  outValues.reserve(sourceType.getNumElements());
2663  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2664  elems.begin(), counts, offsets, sizes, strides, &outValues);
2665  newAttr = DenseElementsAttr::get(resultType, outValues);
2666  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2667  SmallVector<APFloat> outValues;
2668  outValues.reserve(sourceType.getNumElements());
2669  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2670  elems.begin(), counts, offsets, sizes, strides, &outValues);
2671  newAttr = DenseElementsAttr::get(resultType, outValues);
2672  }
2673 
2674  if (newAttr) {
2675  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2676  return success();
2677  }
2678 
2679  return failure();
2680  }
2681 
2682 private:
2683  /// Additionally controls whether the fold happens or not; users can impose
2684  /// their own heuristics in this function.
2685  ControlConstantExtractSliceFusionFn controlFn;
2686 };
2687 
2688 } // namespace
2689 
2690 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2691  RewritePatternSet &patterns,
2692  const ControlConstantExtractSliceFusionFn &controlFn) {
2693  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2694 }
2695 
2696 /// Return the canonical type of the result of an extract_slice op.
2697 struct SliceReturnTypeCanonicalizer {
2698  RankedTensorType operator()(ExtractSliceOp op,
2699  ArrayRef<OpFoldResult> mixedOffsets,
2700  ArrayRef<OpFoldResult> mixedSizes,
2701  ArrayRef<OpFoldResult> mixedStrides) {
2702  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2703  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2704  mixedStrides);
2705  }
2706 };
2707 
2708 /// A canonicalizer wrapper to replace ExtractSliceOps.
2709 struct SliceCanonicalizer {
2710  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2711  ExtractSliceOp newOp) {
2712  Value replacement = newOp.getResult();
2713  if (replacement.getType() != op.getType())
2714  replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2715  replacement);
2716  rewriter.replaceOp(op, replacement);
2717  }
2718 };
2719 
2720 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2721  MLIRContext *context) {
2722  results.add<
2723  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2724  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2725  ExtractSliceOpCastFolder>(context);
2726 }
2727 
2728 // Returns success if `op` is an identity slice of `shapedType`: all offsets
2729 static LogicalResult
2730 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2731  ShapedType shapedType) {
2732  OpBuilder b(op.getContext());
2733  for (OpFoldResult ofr : op.getMixedOffsets())
2734  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2735  return failure();
2736  // Rank-reducing noops only need to inspect the leading dimensions:
2737  // llvm::zip is appropriate.
2738  auto shape = shapedType.getShape();
2739  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2740  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2741  return failure();
2742  for (OpFoldResult ofr : op.getMixedStrides())
2743  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2744  return failure();
2745  return success();
2746 }
2747 
2748 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2749 /// slice, we can return the InsertSliceOp's source directly.
2750 // TODO: This only checks the immediate producer; extend to go up the
2751 // insert/extract chain if the slices are disjoint.
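/// For example (an illustrative round-trip):
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 0][4, 4][1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// %1 = tensor.extract_slice %0[0, 0][4, 4][1, 1]
///     : tensor<8x8xf32> to tensor<4x4xf32>
/// ```
///
/// Here %1 can be replaced with %src directly.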
2752 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2753  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2754 
2755  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2756  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2757  insertOp.isSameAs(extractOp, isSame))
2758  return insertOp.getSource();
2759 
2760  return {};
2761 }
2762 
2763 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2764  if (OpFoldResult reshapedSource = reshapeConstantSource(
2765  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2766  getResult().getType()))
2767  return reshapedSource;
2768  if (getSourceType() == getType() &&
2769  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2770  return this->getSource();
2771  if (Value slice = foldExtractAfterInsertSlice(*this))
2772  return slice;
2773 
2774  return OpFoldResult();
2775 }
2776 
2777 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2778  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2779  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2780  unsigned rank = rankedTensorType.getRank();
2781  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2782  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2783  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2784  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2785  offsets, sizes, strides);
2786 }
2787 
2788 //===----------------------------------------------------------------------===//
2789 // InsertSliceOp
2790 //===----------------------------------------------------------------------===//
2791 
2792 void InsertSliceOp::getAsmResultNames(
2793  function_ref<void(Value, StringRef)> setNameFn) {
2794  setNameFn(getResult(), "inserted_slice");
2795 }
2796 
2797 // Build an InsertSliceOp with mixed static and dynamic entries.
2798 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2799  Value dest, ArrayRef<OpFoldResult> offsets,
2800  ArrayRef<OpFoldResult> sizes,
2801  ArrayRef<OpFoldResult> strides,
2802  ArrayRef<NamedAttribute> attrs) {
2803  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2804  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2805  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2806  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2807  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2808  result.addAttributes(attrs);
2809  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2810  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2811  b.getDenseI64ArrayAttr(staticSizes),
2812  b.getDenseI64ArrayAttr(staticStrides));
2813 }
2814 
2815 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2816 /// Range vector.
2817 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2818  Value dest, ArrayRef<Range> ranges,
2819  ArrayRef<NamedAttribute> attrs) {
2820  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2821  build(b, result, source, dest, offsets, sizes, strides, attrs);
2822 }
2823 
2824 // Build an InsertSliceOp with dynamic entries.
2825 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2826  Value dest, ValueRange offsets, ValueRange sizes,
2827  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2828  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2829  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2830  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2831  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2832  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2833  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2834  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2835 }
2836 
2837 /// Rank-reducing type verification for both InsertSliceOp and
2838 /// ParallelInsertSliceOp.
2840  RankedTensorType srcType, RankedTensorType dstType,
2841  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2842  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2843  // insert_slice is the inverse of extract_slice, use the same type
2844  // inference.
2845  RankedTensorType expected = ExtractSliceOp::inferResultType(
2846  dstType, staticOffsets, staticSizes, staticStrides);
2847  if (expectedType)
2848  *expectedType = expected;
2849  return isRankReducedType(expected, srcType);
2850 }
2851 
2852 /// Verifier for InsertSliceOp.
2853 LogicalResult InsertSliceOp::verify() {
2854  // Verify result type against inferred type.
2855  RankedTensorType expectedType;
2856  SliceVerificationResult result =
2857  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2858  getStaticSizes(), getStaticStrides(), &expectedType);
2859  if (result != SliceVerificationResult::Success)
2860  return produceSliceErrorMsg(result, *this, expectedType);
2861 
2862  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2863  // to the destination tensor.
2864  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2865  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2866  getStaticStrides(), /*generateErrorMessage=*/true);
2867  if (!boundsResult.isValid)
2868  return getOperation()->emitError(boundsResult.errorMessage);
2869 
2870  return success();
2871 }
2872 
2873 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2874 /// can mutate the second InsertSliceOp's destination to the first one's.
2875 ///
2876 /// Example:
2877 ///
2878 /// ```mlir
2879 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2880 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2881 /// ```
2882 ///
2883 /// folds into:
2884 ///
2885 /// ```mlir
2886 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2887 /// ```
2888 ///
2889 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2890 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2891  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2892 
2893  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2894  if (!prevInsertOp ||
2895  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2896  !prevInsertOp.isSameAs(insertOp, isSame))
2897  return failure();
2898 
2899  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2900  return success();
2901 }
2902 
2903 /// Folds round-trip extract/insert slice op pairs.
2904 /// Example:
2905 /// ```mlir
2906 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2907 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2908 /// ```
2909 /// can be folded into %val.
2910 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2911  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2912 
2913  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2914  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2915  !extractOp.isSameAs(insertOp, isSame))
2916  return nullptr;
2917 
2918  return extractOp.getSource();
2919 }
2920 
2921 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2922  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2923  getSourceType() == getType() &&
2924  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2925  return this->getSource();
2926  if (succeeded(foldInsertAfterInsertSlice(*this)))
2927  return getResult();
2928  if (auto result = foldInsertAfterExtractSlice(*this))
2929  return result;
2930  if (llvm::any_of(getMixedSizes(), isZeroInteger))
2931  return getDest();
2932  return OpFoldResult();
2933 }
2934 
2935 LogicalResult InsertSliceOp::reifyResultShapes(
2936  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2937  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2938  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2939  return success();
2940 }
2941 
2942 namespace {
2943 /// Pattern to rewrite an insert_slice op with constant arguments.
2944 ///
2945 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2946 template <typename InsertOpTy>
2947 class InsertSliceOpConstantArgumentFolder final
2948  : public OpRewritePattern<InsertOpTy> {
2949 public:
2950  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2951 
2952  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2953  PatternRewriter &rewriter) const override {
2954  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2955  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2956  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2957 
2958  // No constant operands were folded, just return.
2959  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2960  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2961  failed(foldDynamicStrideList(mixedStrides)))
2962  return failure();
2963 
2964  // Pattern does not apply if the produced op would not verify.
2965  SliceBoundsVerificationResult sliceResult =
2966  verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2967  mixedOffsets, mixedSizes, mixedStrides);
2968  if (!sliceResult.isValid)
2969  return failure();
2970 
2971  // Create the new op in canonical form.
2972  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2973  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2974  mixedOffsets, mixedSizes, mixedStrides);
2975  Value toInsert = insertSliceOp.getSource();
2976  if (sourceType != insertSliceOp.getSourceType()) {
2977  OpBuilder::InsertionGuard g(rewriter);
2978  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2979  // is that the insertion point is just before the InParallelOp in
2980  // the parallel case.
2981  if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
2982  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2983  toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
2984  sourceType, toInsert);
2985  }
2986  rewriter.replaceOpWithNewOp<InsertOpTy>(
2987  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2988  mixedSizes, mixedStrides);
2989  return success();
2990  }
2991 };
2992 
2993 /// Fold tensor_casts with insert_slice operations. If the source or
2994 /// destination tensor is a tensor_cast that removes static type information,
2995 /// the cast is folded into the insert_slice operation. E.g.:
2996 ///
2997 /// ```mlir
2998 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2999 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3000 /// ```
3001 ///
3002 /// folds into:
3003 ///
3004 /// ```mlir
3005 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3006 /// ```
3007 ///
3008 /// Note: When folding a cast on the destination tensor, the result of the
3009 /// insert_slice operation is cast to ensure that the type of the result does
3010 /// not change.
3011 ///
3012 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3013 template <typename InsertOpTy>
3014 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3015  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3016 
3017  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3018  PatternRewriter &rewriter) const override {
3019  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3020  return matchPattern(operand, matchConstantIndex());
3021  }))
3022  return failure();
3023 
3024  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3025  auto castOp = v.getDefiningOp<tensor::CastOp>();
3026  if (!castOp || !canFoldIntoConsumerOp(castOp))
3027  return std::nullopt;
3028  return castOp.getSource();
3029  };
3030  std::optional<Value> sourceCastSource =
3031  getSourceOfCastOp(insertSliceOp.getSource());
3032  std::optional<Value> destCastSource =
3033  getSourceOfCastOp(insertSliceOp.getDest());
3034  if (!sourceCastSource && !destCastSource)
3035  return failure();
3036 
3037  auto src =
3038  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3039  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3040  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3041  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3042  if (!srcType || !dstType)
3043  return failure();
3044 
3045  // The tensor.cast source could have additional static information not seen
3046  // in the insert slice op static sizes, so we ignore dynamic dims when
3047  // computing the rank reduction mask.
3048  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3049  auto rankReductionMask = computeRankReductionMask(
3050  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3051  if (!rankReductionMask.has_value())
3052  return failure();
3053  // Replace dimensions in the insert slice op with corresponding static dims
3054  // from the cast source type. If the insert slice sizes have static dims
3055  // that are not static in the tensor.cast source (i.e., when the cast op
3056  // casts a dynamic dim to static), the dim should not be replaced, and the
3057  // pattern will fail later in `verifyInsertSliceOp`.
3058  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3059  int64_t rankReducedIdx = 0;
3060  for (auto [idx, size] : enumerate(staticSizes)) {
3061  if (!rankReductionMask.value().contains(idx) &&
3062  !srcType.isDynamicDim(rankReducedIdx)) {
3063  mixedSizes[idx] = getAsIndexOpFoldResult(
3064  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3065  size = srcType.getDimSize(rankReducedIdx++);
3066  }
3067  }
3068 
3069  // Pattern does not apply if the produced op would not verify.
3070  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3071  staticSizes, insertSliceOp.getStaticStrides()) !=
3072  SliceVerificationResult::Success)
3073  return failure();
3074  SliceBoundsVerificationResult sliceResult =
3075  verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3076  mixedSizes, insertSliceOp.getMixedStrides());
3077  if (!sliceResult.isValid)
3078  return failure();
3079 
3080  Operation *replacement =
3081  InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3082  insertSliceOp.getMixedOffsets(), mixedSizes,
3083  insertSliceOp.getMixedStrides());
3084 
3085  // In the parallel case there is no result and so nothing to cast.
3086  bool isParallelInsert =
3087  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3088  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3089  replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3090  insertSliceOp.getDestType(),
3091  replacement->getResult(0));
3092  }
3093  rewriter.replaceOp(insertSliceOp, replacement->getResults());
3094  return success();
3095  }
3096 };
3097 
3098 /// If additional static type information can be deduced from an insert_slice's
3099 /// size operands, insert an explicit cast of the op's source operand. This
3100 /// enables other canonicalization patterns that are matching for tensor_cast
3101 /// ops such as `ForOpTensorCastFolder` in SCF.
3102 ///
3103 /// Example:
3104 ///
3105 /// ```mlir
3106 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3107 /// : tensor<?x?xf32> into ...
3108 /// ```
3109 ///
3110 /// folds into:
3111 ///
3112 /// ```mlir
3113 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3114 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3115 /// : tensor<64x64xf32> into ...
3116 /// ```
3117 ///
3118 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3119 template <typename InsertOpTy>
3120 struct InsertSliceOpSourceCastInserter final
3121  : public OpRewritePattern<InsertOpTy> {
3122  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3123 
3124  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3125  PatternRewriter &rewriter) const override {
3126  RankedTensorType srcType = insertSliceOp.getSourceType();
3127  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3128  return failure();
3129  SmallVector<int64_t> newSrcShape(srcType.getShape());
3130  for (int64_t i = 0; i < srcType.getRank(); ++i) {
3131  if (std::optional<int64_t> constInt =
3132  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3133  // Bail on invalid IR.
3134  if (*constInt < 0)
3135  return failure();
3136  newSrcShape[i] = *constInt;
3137  }
3138  }
3139  if (!hasValidSizesOffsets(newSrcShape))
3140  return failure();
3141 
3142  RankedTensorType newSrcType = RankedTensorType::get(
3143  newSrcShape, srcType.getElementType(), srcType.getEncoding());
3144  if (srcType == newSrcType ||
3145  !preservesStaticInformation(srcType, newSrcType) ||
3146  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3147  return failure();
3148 
3149  // newSrcType is:
3150  // 1) Different from srcType.
3151  // 2) "More static" than srcType.
3152  // 3) Cast-compatible with srcType.
3153  // Insert the cast.
3154  OpBuilder::InsertionGuard g(rewriter);
3155  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3156  // that the insertion point is just before the InParallelOp in the
3157  // parallel case.
3158  if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3159  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3160  Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3161  newSrcType, insertSliceOp.getSource());
3162  rewriter.replaceOpWithNewOp<InsertOpTy>(
3163  insertSliceOp, cast, insertSliceOp.getDest(),
3164  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3165  insertSliceOp.getMixedStrides());
3166  return success();
3167  }
3168 };
3169 } // namespace
3170 
3171 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3172  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3173 }
3174 
3175 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3176  MLIRContext *context) {
3177  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3178  InsertSliceOpCastFolder<InsertSliceOp>,
3179  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3180 }
3181 
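/// Create (or fold away) an insert_slice that writes `tensor` over the full
/// extent of `dest`: zero offsets, unit strides, and sizes taken from `dest`.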
3182 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3183  Location loc,
3184  Value tensor,
3185  Value dest) {
3186  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3187  unsigned rank = rankedTensorType.getRank();
3188  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3189  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3190  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3191  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3192  sizes, strides);
3193 }
3194 
3195 //===----------------------------------------------------------------------===//
3196 // PadOp
3197 //===----------------------------------------------------------------------===//
3198 
3199 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3200  setNameFn(getResult(), "padded");
3201 }
3202 
3203 LogicalResult PadOp::verify() {
3204  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3205  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3206  auto expectedType =
3207  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3208  if (!expectedType) {
3209  return emitError("failed to infer expectedType from sourceType ")
3210  << sourceType << ", specified resultType is " << resultType;
3211  }
3212  if (resultType.getRank() != expectedType.getRank()) {
3213  return emitError("specified type ")
3214  << resultType << " does not match the inferred type "
3215  << expectedType;
3216  }
3217  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3218  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3219  continue;
3220  if (expectedType.isDynamicDim(i))
3221  continue;
3222  return emitError("specified type ")
3223  << resultType << " does not match the inferred type "
3224  << expectedType;
3225  }
3226 
3227  return success();
3228 }
3229 
3230 LogicalResult PadOp::verifyRegions() {
3231  auto &region = getRegion();
3232  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3233  Block &block = region.front();
3234  if (block.getNumArguments() != rank)
3235  return emitError("expected the block to have ") << rank << " arguments";
3236 
3237  // Note: the number and type of yield values are checked in the YieldOp.
3238  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3239  if (!en.value().isIndex())
3240  return emitOpError("expected block argument ")
3241  << (en.index() + 1) << " to be an index";
3242  }
3243 
3244  // Ensure that the region yields an element of the right type.
3245  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3246  if (yieldOp.getValue().getType() !=
3247  llvm::cast<ShapedType>(getType()).getElementType())
3248  return emitOpError("expected yield type to match shape element type");
3249 
3250  return success();
3251 }
3252 
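/// Note: a result dimension stays static only when the source dimension and
/// both padding amounts are static, in which case it is their sum. For example
/// (hypothetical shapes), padding tensor<4x?xf32> with low [1, 1] and
/// high [2, 2] infers tensor<7x?xf32>.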
3253 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3254  ArrayRef<int64_t> staticLow,
3255  ArrayRef<int64_t> staticHigh,
3256  ArrayRef<int64_t> resultShape) {
3257  unsigned rank = sourceType.getRank();
3258  if (staticLow.size() != rank)
3259  return RankedTensorType();
3260  if (staticHigh.size() != rank)
3261  return RankedTensorType();
3262  if (!resultShape.empty() && resultShape.size() != rank)
3263  return RankedTensorType();
3264 
3265  SmallVector<int64_t, 4> inferredShape;
3266  for (auto i : llvm::seq<unsigned>(0, rank)) {
3267  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3268  staticHigh[i] == ShapedType::kDynamic) {
3269  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3270  : resultShape[i]);
3271  } else {
3272  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3273  assert((resultShape.empty() || size == resultShape[i] ||
3274  resultShape[i] == ShapedType::kDynamic) &&
3275  "mismatch between inferred shape and result shape");
3276  inferredShape.push_back(size);
3277  }
3278  }
3279 
3280  return RankedTensorType::get(inferredShape, sourceType.getElementType());
3281 }
3282 
3283 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3284  Value source, ArrayRef<int64_t> staticLow,
3285  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3286  bool nofold, ArrayRef<NamedAttribute> attrs) {
3287  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3288  if (!resultType)
3289  resultType = inferResultType(sourceType, staticLow, staticHigh);
3290  result.addAttributes(attrs);
3291  build(b, result, resultType, source, low, high,
3292  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3293  nofold ? b.getUnitAttr() : UnitAttr());
3294 }
3295 
3296 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3297  Value source, ValueRange low, ValueRange high, bool nofold,
3298  ArrayRef<NamedAttribute> attrs) {
3299  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3300  unsigned rank = sourceType.getRank();
3301  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3302  build(b, result, resultType, source, staticVector, staticVector, low, high,
3303  nofold, attrs);
3304 }
3305 
3306 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3307  Value source, ArrayRef<OpFoldResult> low,
3308  ArrayRef<OpFoldResult> high, bool nofold,
3309  ArrayRef<NamedAttribute> attrs) {
3310  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3311  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3312  SmallVector<int64_t, 4> staticLow, staticHigh;
3313  // staticLow and staticHigh carry the full padding configuration: each call
3314  // below appends one entry per dimension. If an entry is dynamic (i.e., not a
3315  // constant), a corresponding value is also appended to dynamicLow and
3316  // dynamicHigh.
3317  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3318  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3319  if (!resultType) {
3320  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3321  }
3322  assert(llvm::isa<RankedTensorType>(resultType));
3323  result.addAttributes(attrs);
3324  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3325  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3326  nofold ? b.getUnitAttr() : UnitAttr());
3327 }
3328 
3329 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3330  Value source, ArrayRef<OpFoldResult> low,
3331  ArrayRef<OpFoldResult> high, Value constantPadValue,
3332  bool nofold, ArrayRef<NamedAttribute> attrs) {
3333  build(b, result, resultType, source, low, high, nofold, attrs);
3334 
3335  // Add a region and a block to yield the pad value.
3336  Region *region = result.regions[0].get();
3337  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3338  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3339  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3340 
3341  // `builder.createBlock` changes the insertion point within the block. Create
3342  // a guard that resets the builder's insertion point when it goes out of scope.
3343  OpBuilder::InsertionGuard guard(b);
3344  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3345  tensor::YieldOp::create(b, result.location, constantPadValue);
3346 }
3347 
3348 llvm::SmallBitVector PadOp::getPaddedDims() {
3349  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3350  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3351  for (const auto &en : enumerate(paddingWidths))
3352  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3353  paddedDims.set(en.index());
3354  };
3355  extractPaddedDims(getMixedLowPad());
3356  extractPaddedDims(getMixedHighPad());
3357  return paddedDims;
3358 }
3359 
3360 namespace {
3361 // Folds tensor.pad away when all low/high padding amounts are static zeros
3362 // and the `nofold` attribute is not set.
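//
// Illustrative sketch (hypothetical shapes and names): a pad with all-zero
// padding amounts such as
//
// ```mlir
// %p = tensor.pad %t low[0, 0] high[0, 0] { ...
// } : tensor<?x4xf32> to tensor<4x4xf32>
// ```
//
// is replaced by `tensor.cast %t : tensor<?x4xf32> to tensor<4x4xf32>`.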
3363 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3364  using OpRewritePattern<PadOp>::OpRewritePattern;
3365 
3366  LogicalResult matchAndRewrite(PadOp padTensorOp,
3367  PatternRewriter &rewriter) const override {
3368  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3369  return failure();
3370  if (padTensorOp.getNofold())
3371  return failure();
3372  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3373  padTensorOp, padTensorOp.getResult().getType(),
3374  padTensorOp.getSource());
3375  return success();
3376  }
3377 };
3378 
3379 // Fold CastOp into PadOp when adding static information.
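//
// Illustrative sketch (hypothetical shapes and names):
//
// ```mlir
// %0 = tensor.cast %arg : tensor<4x4xf32> to tensor<?x?xf32>
// %p = tensor.pad %0 low[0, 0] high[1, 1] { ...
// } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// becomes a pad of %arg with the inferred result type tensor<5x5xf32>,
// followed by a `tensor.cast` back to the original tensor<?x?xf32> type.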
3380 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3381  using OpRewritePattern<PadOp>::OpRewritePattern;
3382 
3383  LogicalResult matchAndRewrite(PadOp padTensorOp,
3384  PatternRewriter &rewriter) const override {
3385  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3386  if (!tensor::canFoldIntoConsumerOp(castOp))
3387  return failure();
3388 
3389  auto newResultType = PadOp::inferResultType(
3390  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3391  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3392  padTensorOp.getResultType().getShape());
3393 
3394  if (newResultType == padTensorOp.getResultType()) {
3395  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3396  padTensorOp.getSourceMutable().assign(castOp.getSource());
3397  });
3398  } else {
3399  auto newOp = PadOp::create(
3400  rewriter, padTensorOp->getLoc(), newResultType,
3401  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3402  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3403  padTensorOp.getHigh(), padTensorOp.getNofold(),
3404  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3405  IRMapping mapper;
3406  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3407 
3408  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3409  padTensorOp, padTensorOp.getResultType(), newOp);
3410  }
3411  return success();
3412  }
3413 };
3414 
3415 // Fold CastOp using the result of PadOp back into the latter if it adds
3416 // static information.
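//
// Illustrative sketch (hypothetical shapes and names):
//
// ```mlir
// %p = tensor.pad %t low[0, 0] high[%h, 0] { ...
// } : tensor<4x4xf32> to tensor<?x4xf32>
// %c = tensor.cast %p : tensor<?x4xf32> to tensor<8x4xf32>
// ```
//
// becomes a single pad that directly produces tensor<8x4xf32>, provided the
// pad result has no other users.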
3417 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3418  using OpRewritePattern<PadOp>::OpRewritePattern;
3419 
3420  LogicalResult matchAndRewrite(PadOp padTensorOp,
3421  PatternRewriter &rewriter) const override {
3422  if (!padTensorOp.getResult().hasOneUse())
3423  return failure();
3424  auto tensorCastOp =
3425  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3426  if (!tensorCastOp)
3427  return failure();
3428  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3429  tensorCastOp.getDest().getType()))
3430  return failure();
3431 
3432  auto replacementOp = PadOp::create(
3433  rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3434  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3435  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3436  padTensorOp.getHigh(), padTensorOp.getNofold(),
3437  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3438  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3439 
3440  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3441  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3442  return success();
3443  }
3444 };
3445 
3446 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3447 /// different dimensions. The pattern applies if the following preconditions
3448 /// hold:
3449 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3450 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3451 /// 3) the tensor::PadOps perform only high-padding,
3452 /// 4) the tensor::PadOps have the same constant padding value,
3453 /// 5) the tensor::PadOps do not have common padding dimensions,
3454 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3455 /// zero-offset for every dimension.
3456 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3457 /// the
3458 /// padded source dimensions.
3459 ///
3460 /// Example:
3461 ///
3462 /// ```mlir
3463 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3464 /// : tensor<64x64xf32> to tensor<?x64xf32>
3465 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3466 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3467 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3468 /// : tensor<8x64xf32> to tensor<8x?xf32>
3469 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3470 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3471 /// ```
3472 ///
3473 /// folds into:
3474 ///
3475 /// ```mlir
3476 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3477 /// : tensor<64x64xf32> to tensor<?x?xf32>
3478 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3479 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3480 /// ```
3481 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3482  using OpRewritePattern<PadOp>::OpRewritePattern;
3483 
3484  LogicalResult matchAndRewrite(PadOp padOp,
3485  PatternRewriter &rewriter) const override {
3486  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3487  if (!innerSliceOp)
3488  return failure();
3489  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3490  if (!outerPadOp || outerPadOp.getNofold())
3491  return failure();
3492  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3493  if (!outerSliceOp)
3494  return failure();
3495 
3496  // 1) Fail if the chain is rank-reducing.
3497  int64_t rank = padOp.getSourceType().getRank();
3498  if (outerSliceOp.getSourceType().getRank() != rank) {
3499  return rewriter.notifyMatchFailure(padOp,
3500  "cannot fold rank-reducing chain");
3501  }
3502 
3503  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3504  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3505  return rewriter.notifyMatchFailure(
3506  padOp, "cannot fold non-unit stride ExtractSliceOps");
3507  }
3508 
3509  // 3) Fail if the tensor::PadOps have non-zero low padding.
3510  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3511  return rewriter.notifyMatchFailure(padOp,
3512  "cannot fold PadOps with low padding");
3513  }
3514 
3515  // 4) Fail if the tensor::PadOps padding values do not match.
3516  Attribute innerAttr, outerAttr;
3517  Value innerValue = padOp.getConstantPaddingValue();
3518  Value outerValue = outerPadOp.getConstantPaddingValue();
3519  if (!innerValue || !outerValue ||
3520  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3521  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3522  innerAttr != outerAttr) {
3523  return rewriter.notifyMatchFailure(
3524  padOp, "cannot fold PadOps with different padding values");
3525  }
3526 
3527  // 5) Fail if a dimension is padded by both tensor::PadOps.
3528  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3529  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3530  if (innerDims.anyCommon(outerDims)) {
3531  return rewriter.notifyMatchFailure(
3532  padOp, "cannot fold PadOps with common padding dimensions");
3533  }
3534 
3535  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3536  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3537  // for every dimension, and use the offset of the other pair. Fail if no
3538  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3539  // exists.
3540  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3541  for (auto en : enumerate(newOffsets)) {
3542  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3543  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3544  if (!innerDims.test(en.index()) &&
3545  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3546  en.value() = outerOffset;
3547  continue;
3548  }
3549  if (!outerDims.test(en.index()) &&
3550  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3551  en.value() = innerOffset;
3552  continue;
3553  }
3554  return rewriter.notifyMatchFailure(
3555  padOp, "cannot find zero-offset and zero-padding pair");
3556  }
3557 
3558  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3559  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3560  // outer tensor::PadOp and fail if the size of the inner
3561  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3562  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3563  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3564  for (auto en : enumerate(newSizes)) {
3565  if (!outerDims.test(en.index()))
3566  continue;
3567  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3568  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3569  assert(ShapedType::isStatic(sourceSize) &&
3570  "expected padded dimension to have a static size");
3571  if (getConstantIntValue(sliceSize) != sourceSize) {
3572  return rewriter.notifyMatchFailure(
3573  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3574  "match the size of the outer padding");
3575  }
3576  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3577  }
3578 
3579  // Combine the high paddings of the two tensor::PadOps.
3580  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3581  for (auto en : enumerate(newHighPad)) {
3582  if (innerDims.test(en.index()))
3583  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3584  if (outerDims.test(en.index()))
3585  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3586  }
3587 
3588  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3589  // the two paddings in one step.
3590  auto newSliceOp = ExtractSliceOp::create(
3591  rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3592  newSizes, innerSliceOp.getMixedStrides());
3593  auto newPadOp = PadOp::create(
3594  rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3595  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3596  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3597  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3598  newPadOp.getRegion().begin());
3599  rewriter.replaceOp(padOp, newPadOp.getResult());
3600  return success();
3601  }
3602 };
3603 
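/// Replaces dynamic `low`/`high` padding operands that are bound to constants
/// with their static equivalents, computes a more static result type, and
/// casts the new result back to the original result type.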
3604 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3605  using OpRewritePattern<PadOp>::OpRewritePattern;
3606 
3607  LogicalResult matchAndRewrite(PadOp padTensorOp,
3608  PatternRewriter &rewriter) const override {
3609  Value input = padTensorOp.getSource();
3610  if (!llvm::isa<RankedTensorType>(input.getType()))
3611  return failure();
3612  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3613  auto inputRank = inputDims.size();
3614 
3615  auto oldResultType =
3616  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3617  if (!oldResultType)
3618  return failure();
3619 
3620  auto outputDims = oldResultType.getShape();
3621 
3622  // Extract the static info from the high and low operands.
3623  SmallVector<int64_t> constOperandsLow;
3624  SmallVector<Value> newLows;
3625  for (auto operand : padTensorOp.getLow()) {
3626  APSInt intOp;
3627  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3628  constOperandsLow.push_back(ShapedType::kDynamic);
3629  newLows.push_back(operand);
3630  continue;
3631  }
3632  constOperandsLow.push_back(intOp.getExtValue());
3633  }
3634  SmallVector<int64_t> constOperandsHigh;
3635  SmallVector<Value> newHighs;
3636  for (auto operand : padTensorOp.getHigh()) {
3637  APSInt intOp;
3638  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3639  constOperandsHigh.push_back(ShapedType::kDynamic);
3640  newHighs.push_back(operand);
3641  continue;
3642  }
3643  constOperandsHigh.push_back(intOp.getExtValue());
3644  }
3645 
3646  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3647  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3648 
3649  // Verify the op is well-formed.
3650  if (inputDims.size() != outputDims.size() ||
3651  inputDims.size() != constLow.size() ||
3652  inputDims.size() != constHigh.size())
3653  return failure();
3654 
3655  auto lowCount = 0;
3656  auto highCount = 0;
3657  for (size_t i = 0; i < inputRank; i++) {
3658  if (constLow[i] == ShapedType::kDynamic)
3659  constLow[i] = constOperandsLow[lowCount++];
3660  if (constHigh[i] == ShapedType::kDynamic)
3661  constHigh[i] = constOperandsHigh[highCount++];
3662  }
3663 
3664  auto staticLow = ArrayRef<int64_t>(constLow);
3665  auto staticHigh = ArrayRef<int64_t>(constHigh);
3666 
3667  // Calculate the output sizes with the static information.
3668  SmallVector<int64_t> newOutDims;
3669  for (size_t i = 0; i < inputRank; i++) {
3670  if (outputDims[i] == ShapedType::kDynamic) {
3671  newOutDims.push_back(
3672  (staticLow[i] == ShapedType::kDynamic ||
3673  staticHigh[i] == ShapedType::kDynamic ||
3674  inputDims[i] == ShapedType::kDynamic
3675  ? ShapedType::kDynamic
3676  : inputDims[i] + staticLow[i] + staticHigh[i]));
3677  } else {
3678  newOutDims.push_back(outputDims[i]);
3679  }
3680  }
3681 
3682  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3683  llvm::all_of(newOutDims,
3684  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3685  return failure();
3686 
3687  // Rewrite the op using the new static type.
3688  auto newResultType = RankedTensorType::get(
3689  newOutDims, padTensorOp.getType().getElementType());
3690  auto newOp = PadOp::create(
3691  rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3692  staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3693  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3694 
3695  IRMapping mapper;
3696  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3697  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3698  newOp);
3699 
3700  return success();
3701  }
3702 };
3703 
3704 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3705 ///
3706 /// Example:
3707 ///
3708 /// ```mlir
3709 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3710 /// tensor.yield %val
3711 /// } : tensor<2x2xf32> to tensor<2x5xf32>
3712 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3713 /// tensor.yield %val
3714 /// } : tensor<2x5xf32> to tensor<5x7xf32>
3715 /// ```
3716 ///
3717 /// folds into:
3718 ///
3719 /// ```mlir
3720 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3721 /// tensor.yield %val
3722 /// } : tensor<2x2xf32> to tensor<5x7xf32>
3723 /// ```
3724 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3725  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3726 
3727  LogicalResult matchAndRewrite(tensor::PadOp padOp,
3728  PatternRewriter &rewriter) const override {
3729  if (padOp.getNofold()) {
3730  return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3731  }
3732 
3733  auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3734  if (!producerPad || producerPad.getNofold()) {
3735  return rewriter.notifyMatchFailure(
3736  padOp, "producer is not a foldable tensor.pad op");
3737  }
3738 
3739  // Fail if the tensor::PadOps padding values do not match.
3740  Value consumerPadValue = padOp.getConstantPaddingValue();
3741  Value producerPadValue = producerPad.getConstantPaddingValue();
3742  if (!consumerPadValue || !producerPadValue ||
3743  consumerPadValue != producerPadValue) {
3744  return rewriter.notifyMatchFailure(
3745  padOp,
3746  "cannot fold PadOps with different or non-constant padding values");
3747  }
3748 
3749  Location loc = padOp.getLoc();
3750  AffineExpr d0, d1;
3751  bindDims(rewriter.getContext(), d0, d1);
3752 
3753  // Combine the low/high paddings of the two tensor::PadOps.
3754  auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3755  ArrayRef<OpFoldResult> producerPaddings) {
3756  SmallVector<OpFoldResult> sumPaddings;
3757  for (auto [consumerIndex, producerIndex] :
3758  llvm::zip_equal(consumerPaddings, producerPaddings)) {
3759  sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3760  rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3761  }
3762  return sumPaddings;
3763  };
3764 
3765  SmallVector<OpFoldResult> newHighPad =
3766  addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3767  SmallVector<OpFoldResult> newLowPad =
3768  addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3769 
3770  auto newPadOp = tensor::PadOp::create(
3771  rewriter, padOp.getLoc(), padOp.getResultType(),
3772  producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3773  getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3774  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3775  newPadOp.getRegion().begin());
3776  rewriter.replaceOp(padOp, newPadOp.getResult());
3777  return success();
3778  }
3779 };
3780 
3781 } // namespace
3782 
3783 LogicalResult
3784 PadOp::reifyResultShapes(OpBuilder &b,
3785  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3786  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3787  SmallVector<OpFoldResult> lp = getMixedLowPad();
3788  SmallVector<OpFoldResult> hp = getMixedHighPad();
3789  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3790  if (!getType().isDynamicDim(i)) {
3791  reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3792  continue;
3793  }
3794  Location loc = getLoc();
3795  Value dim = b.createOrFold<tensor::DimOp>(
3796  loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3797 
3798  AffineExpr d0, d1, d2;
3799  bindDims(b.getContext(), d0, d1, d2);
3800  reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3801  b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3802  }
3803  return success();
3804 }
3805 
3806 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3807  MLIRContext *context) {
3808  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3809  FoldOrthogonalPaddings, FoldStaticPadding,
3810  FoldConsecutiveConstantPadding>(context);
3811 }
3812 
3813 /// Return the padding value of the PadOp if it is constant. In this context,
3814 /// "constant" means an actual constant or "defined outside of the block".
3815 ///
3816 /// Values are considered constant in three cases:
3817 /// - A ConstantLike value.
3818 /// - A basic block argument from a different block.
3819 /// - A value defined outside of the block.
3820 ///
3821 /// If the padding value is not constant, an empty Value is returned.
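///
/// Illustrative sketch (hypothetical shapes and names): in
///
/// ```mlir
/// %p = tensor.pad %t low[1] high[1] {
///   tensor.yield %cst : f32
/// } : tensor<4xf32> to tensor<6xf32>
/// ```
///
/// %cst is defined above the pad and is returned; a value computed from the
/// pad's own block arguments would not be.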
3822 Value PadOp::getConstantPaddingValue() {
3823  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3824  if (!yieldOp)
3825  return {};
3826  Value padValue = yieldOp.getValue();
3827  // Check if yield value is a constant.
3828  if (matchPattern(padValue, m_Constant()))
3829  return padValue;
3830  // Check if yield value is defined inside the PadOp block.
3831  if (padValue.getParentBlock() == &getRegion().front())
3832  return {};
3833  // Else: Yield value defined outside of the PadOp block.
3834  return padValue;
3835 }
3836 
3837 OpFoldResult PadOp::fold(FoldAdaptor) {
3838  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3839  !getNofold())
3840  return getSource();
3841  return {};
3842 }
3843 
3844 //===----------------------------------------------------------------------===//
3845 // ParallelInsertSliceOp
3846 //===----------------------------------------------------------------------===//
3847 
3848 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3849  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3850  for (const auto &it :
3851  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3852  Operation &nextOp = it.value();
3853  if (&nextOp == getOperation())
3854  return parallelCombiningParent.getParentResult(it.index());
3855  }
3856  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3857 }
3858 
3859 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3860 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3861  Value source, Value dest,
3862  ArrayRef<OpFoldResult> offsets,
3863  ArrayRef<OpFoldResult> sizes,
3864  ArrayRef<OpFoldResult> strides,
3865  ArrayRef<NamedAttribute> attrs) {
3866  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3867  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3868  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3869  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3870  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3871  result.addAttributes(attrs);
3872  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3873  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3874  b.getDenseI64ArrayAttr(staticSizes),
3875  b.getDenseI64ArrayAttr(staticStrides));
3876 }
3877 
3878 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3879 /// packed into a Range vector.
3880 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3881  Value source, Value dest,
3882  ArrayRef<Range> ranges,
3883  ArrayRef<NamedAttribute> attrs) {
3884  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3885  build(b, result, source, dest, offsets, sizes, strides, attrs);
3886 }
3887 
3888 // Build a ParallelInsertSliceOp with dynamic entries.
3889 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3890  Value source, Value dest, ValueRange offsets,
3891  ValueRange sizes, ValueRange strides,
3892  ArrayRef<NamedAttribute> attrs) {
3893  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3894  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3895  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3896  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3897  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3898  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3899  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3900 }
3901 
3902 LogicalResult ParallelInsertSliceOp::verify() {
3903  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3904  return this->emitError("expected InParallelOpInterface parent, got:")
3905  << *(getOperation()->getParentOp());
3906 
3907  // Verify result type against inferred type.
3908  RankedTensorType expectedType;
3909  SliceVerificationResult result =
3910  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3911  getStaticSizes(), getStaticStrides(), &expectedType);
3912  if (result != SliceVerificationResult::Success)
3913  return produceSliceErrorMsg(result, *this, expectedType);
3914 
3915  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3916  // to the destination tensor.
3917  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3918  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3919  getStaticStrides(), /*generateErrorMessage=*/true);
3920  if (!boundsResult.isValid)
3921  return getOperation()->emitError(boundsResult.errorMessage);
3922 
3923  return success();
3924 }
3925 
3926 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3927  RewritePatternSet &results, MLIRContext *context) {
3928  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3929  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3930  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3931 }
3932 
3933 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3934  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3935 }
3936 
3937 // ParallelCombiningOpInterface implementation.
3938 MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3939  return getDestMutable();
3940 }
3941 
3942 Operation *ParallelInsertSliceOp::getIteratingParent() {
3943  // Return the parent InParallelOpInterface's parent.
3944  if (auto combiningOp =
3945  dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3946  return combiningOp->getParentOp();
3947  return nullptr;
3948 }
3949 
3950 //===----------------------------------------------------------------------===//
3951 // ScatterOp
3952 //===----------------------------------------------------------------------===//
3953 
3954 void ScatterOp::getAsmResultNames(
3955  function_ref<void(Value, StringRef)> setNameFn) {
3956  setNameFn(getResult(), "scatter");
3957 }
3958 
3959 LogicalResult ScatterOp::verify() {
3960  int64_t destRank = getDestType().getRank();
3961  ArrayRef<int64_t> scatterDims = getScatterDims();
3962  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3963  getIndicesType().getShape(), destRank,
3964  "scatter", "dest")))
3965  return failure();
3966 
3967  if (!getUnique())
3968  return emitOpError("requires 'unique' attribute to be set");
3969  // TODO: we could also check statically that there are fewer leading index
3970  // tensor dims than the dest dims. If this is not the case, the unique
3971  // attribute cannot be true.
3972 
3973  // Use the GatherOp::inferResultType on the `dest` type and verify the
3974  // expected type matches the source type.
3975  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3976  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3977  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3978  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3979  if (getSourceType() != expectedSourceType &&
3980  getSourceType() != expectedRankReducedSourceType) {
3981  return emitOpError("source type "
3982  "mismatch: "
3983  "expected ")
3984  << expectedSourceType << " or its rank-reduced variant "
3985  << expectedRankReducedSourceType << " (got: " << getSourceType()
3986  << ")";
3987  }
3988 
3989  return success();
3990 }
3991 
3992 //===----------------------------------------------------------------------===//
3993 // SplatOp
3994 //===----------------------------------------------------------------------===//
3995 
3996 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3997  Type aggregateType, ValueRange dynamicSizes) {
3998  build(builder, result, aggregateType, element, dynamicSizes);
3999 }
4000 
4001 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4002  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4003  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4004  build(builder, result, aggregateType, element, dynamicSizes);
4005 }
4006 
4007 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4008  ArrayRef<OpFoldResult> sizes) {
4009  SmallVector<int64_t> staticShape;
4010  SmallVector<Value> dynamicSizes;
4011  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4012  build(builder, result, element, staticShape, dynamicSizes);
4013 }
4014 
4015 void SplatOp::getAsmResultNames(
4016  function_ref<void(Value, StringRef)> setNameFn) {
4017  setNameFn(getResult(), "splat");
4018 }
4019 
4020 LogicalResult SplatOp::verify() {
4021  if (getType().getNumDynamicDims() != getDynamicSizes().size())
4022  return emitOpError("incorrect number of dynamic sizes, has ")
4023  << getDynamicSizes().size() << ", expected "
4024  << getType().getNumDynamicDims();
4025  return success();
4026 }
4027 
4028 LogicalResult
4029 SplatOp::reifyResultShapes(OpBuilder &builder,
4030  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4031  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4032  unsigned ctr = 0;
4033  for (int64_t i = 0; i < getType().getRank(); ++i) {
4034  if (getType().isDynamicDim(i)) {
4035  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4036  } else {
4037  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4038  }
4039  }
4040  return success();
4041 }
4042 
4043 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4044  auto constOperand = adaptor.getInput();
4045  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4046  return {};
4047 
4048  // Do not fold if the splat is not statically shaped
4049  if (!getType().hasStaticShape())
4050  return {};
4051 
4052  // SplatElementsAttr::get treats a single value for the second arg as a
4053  // splat.
4054  return SplatElementsAttr::get(getType(), {constOperand});
4055 }
4056 
4057 //===----------------------------------------------------------------------===//
4058 // Common Canonicalizers and Folders.
4059 //===----------------------------------------------------------------------===//
4060 static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4061  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4062  // 2. Exclude DPS ops that are also LoopLike from this interface as they
4063  // might need special handling of attached regions.
4064  if (isa<InsertSliceOp>(op.getOperation()) ||
4065  isa<LoopLikeOpInterface>(op.getOperation()))
4066  return false;
4067 
4068  return hasFoldableTensorCastOperand(op);
4069 }
4070 
4071 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4072 /// the `tensor.cast` has a source that is more static than the consuming op.
4073 ///
4074 /// Example:
4075 /// ```mlir
4076 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4077 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4078 /// ```
4079 ///
4080 /// folds into:
4081 ///
4082 /// ```mlir
4083 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4084 /// ```
4085 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4086 /// can add the pattern to their canonicalizers.
4087 struct FoldTensorCastProducerOp
4088  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4089  using OpInterfaceRewritePattern<
4090  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4091 
4092  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4093  PatternRewriter &rewriter) const override {
4094 
4095  // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4096  // for that instead.
4097  if (!foldTensorCastPrecondition(op) ||
4098  isa<linalg::RelayoutOpInterface>(*op))
4099  return failure();
4100 
4101  SmallVector<Type> newResultTypes(op->getResultTypes());
4102  SmallVector<Value> newOperands =
4103  getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4104 
4105  // Clone op
4106  auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4107 
4108  SmallVector<Value, 4> replacements;
4109  replacements.reserve(newOp->getNumResults());
4110  for (auto [oldResult, newResult] :
4111  llvm::zip(op->getResults(), newOp->getResults())) {
4112  if (newResult.getType() != oldResult.getType()) {
4113  replacements.push_back(tensor::CastOp::create(
4114  rewriter, op->getLoc(), oldResult.getType(), newResult));
4115  } else {
4116  replacements.push_back(newResult);
4117  }
4118  }
4119  rewriter.replaceOp(op, replacements);
4120 
4121  return success();
4122  }
4123 };
4124 
4125 //===----------------------------------------------------------------------===//
4126 // TensorDialect
4127 //===----------------------------------------------------------------------===//
4128 
4129 void TensorDialect::getCanonicalizationPatterns(
4130  RewritePatternSet &results) const {
4131  results.add<FoldTensorCastProducerOp>(getContext());
4132 }
4133 
4134 //===----------------------------------------------------------------------===//
4135 // TableGen'd op method definitions
4136 //===----------------------------------------------------------------------===//
4137 
4138 #define GET_OP_CLASSES
4139 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition: TensorOps.cpp:420
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
Definition: TensorOps.cpp:1372
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition: TensorOps.cpp:1558
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
Definition: TensorOps.cpp:2437
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
Definition: TensorOps.cpp:4060
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition: TensorOps.cpp:2890
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition: TensorOps.cpp:2730
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition: TensorOps.cpp:2752
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1816
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition: TensorOps.cpp:2839
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
Definition: TensorOps.cpp:184
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:140
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition: TensorOps.cpp:2910
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition: TensorOps.cpp:2043
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
UnitAttr getUnitAttr()
Definition: Builders.cpp:97
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:166
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:367
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:363
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:377
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:50
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:519
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:230
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:241
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
bool hasRank() const
Returns true if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
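For illustration only (not part of this file), a minimal sketch of how this helper is typically used to add two mixed static/dynamic quantities; the OpBuilder b, Location loc, and the OpFoldResult values lo and hi are assumed to exist in the surrounding code:
  AffineExpr s0, s1;
  bindSymbols(b.getContext(), s0, s1);
  // Folds to an index attribute when both operands are static; otherwise
  // materializes (and composes) an affine.apply.
  OpFoldResult sum = affine::makeComposedFoldedAffineApply(
      b, loc, AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, s0 + s1),
      {lo, hi});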
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:391
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:360
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2690
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:353
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:369
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:322
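A hedged usage sketch, assuming value is an mlir::Value operand of some consumer op: when the cast only erases static information, the consumer can use the cast's source directly.
  if (auto castOp = value.getDefiningOp<tensor::CastOp>())
    if (tensor::canFoldIntoConsumerOp(castOp))
      value = castOp.getSource();  // consume the pre-cast tensor directly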
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp at offsets [0 .. 0] with unit strides that inserts tensor into dest.
Definition: TensorOps.cpp:3182
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp at offsets [0 .. 0] with unit strides that extracts a tensor of type targetType.
Definition: TensorOps.cpp:2777
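For illustration, a sketch that pairs the two rank-reducing helpers above; b, loc, source, and the rank-reduced targetType are assumed to come from the caller:
  Value reduced = tensor::createCanonicalRankReducingExtractSliceOp(
      b, loc, source, targetType);
  // ... operate on the rank-reduced tensor ...
  Value restored = tensor::createCanonicalRankReducingInsertSliceOp(
      b, loc, reduced, source);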
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:128
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:61
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
Definition: TensorOps.cpp:1443
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:79
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:167
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves the static information available in the source type.
Definition: TensorOps.cpp:270
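A small sketch of the check, assuming an OpBuilder b is available; casting tensor<8x4xf32> to tensor<?x4xf32> loses the static size of dim 0:
  auto staticTy  = RankedTensorType::get({8, 4}, b.getF32Type());
  auto dynamicTy = RankedTensorType::get({ShapedType::kDynamic, 4}, b.getF32Type());
  // false: the target erases a size that was static in the source.
  bool keeps = tensor::preservesStaticInformation(/*source=*/staticTy,
                                                  /*target=*/dynamicTy);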
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
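A usage sketch, assuming b, loc, and a ranked tensor Value named tensorValue; static dimensions come back as index attributes and dynamic ones as tensor.dim results:
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, tensorValue);
  OpFoldResult d0 = tensor::getMixedSize(b, loc, tensorValue, /*dim=*/0);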
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
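A minimal sketch of how a transform typically gathers init/destination operands before rewriting, assuming op is the Operation* being processed:
  SmallVector<Value> destinations;
  if (failed(tensor::getOrCreateDestinations(b, loc, op, destinations)))
    return failure();  // result shapes could not be reified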
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
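A hedged sketch (indexValue is an assumed mlir::Value): bind the constant integer behind an SSA value, if there is one:
  APInt cst;
  if (matchPattern(indexValue, m_ConstantInt(&cst))) {
    // cst now holds the matched constant, e.g. cst.getSExtValue().
  }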
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
Definition: BuiltinTypes.h:356
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
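Illustrative only: the usual way a pattern checks whether a mixed offset/size/stride is a known constant, with ofr an assumed OpFoldResult:
  if (std::optional<int64_t> c = getConstantIntValue(ofr)) {
    // *c is the static value, regardless of whether ofr held an Attribute
    // or a constant-defined Value.
  }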
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
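A hedged sketch of a static bounds check for an extract_slice-style access; the shapes and offsets below are made-up example values:
  SliceBoundsVerificationResult res = verifyInBoundsSlice(
      /*shape=*/{16, 32}, /*staticOffsets=*/{0, 8}, /*staticSizes=*/{16, 16},
      /*staticStrides=*/{1, 1}, /*generateErrorMessage=*/true);
  if (!res.isValid)
    return failure();  // res.errorMessage describes the violation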
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions [0 .. N-1], where N is the number of expressions passed.
Definition: AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
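A worked example: with row-major strides {3, 1} for a 2x3 space, linear index 5 corresponds to position (1, 2):
  SmallVector<int64_t> coords = delinearize(/*linearIndex=*/5, /*strides=*/{3, 1});
  // coords == {1, 2}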
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that would prevent erasing it.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
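A sketch of the common pattern for feeding mixed values into op builders, assuming mixedSizes is an ArrayRef<OpFoldResult>:
  SmallVector<Value> dynamicSizes;
  SmallVector<int64_t> staticSizes;
  // Dynamic entries become ShapedType::kDynamic in staticSizes; their Values
  // are collected in dynamicSizes.
  dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);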
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
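Illustrative usage, assuming b, loc, and an OpFoldResult ofr; a constant index op is created only when ofr holds an attribute:
  Value idx = getValueOrCreateConstantIndexOp(b, loc, ofr);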
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:23
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:90
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, where every element for which ShapedType::isDynamic is true is replaced by the corresponding entry of dynamicValues.
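A sketch of recombining a static shape with its dynamic size Values, assuming shapedType is a ShapedType and dynamicSizes lists one Value per dynamic dimension:
  SmallVector<OpFoldResult> mixed =
      getMixedValues(shapedType.getShape(), dynamicSizes, b.getContext());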
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
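A worked example: reducing {1, 4, 1, 8} to {4, 8} drops the unit dimensions at positions 0 and 2:
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask({1, 4, 1, 8}, {4, 8});
  // mask contains {0, 2}; std::nullopt would mean no valid reduction exists.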
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType by dropping some dimensions with static size 1.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
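Illustrative usage, the inverse of getMixedValues above; mixedSizes is an assumed ArrayRef<OpFoldResult>:
  auto [staticVals, dynamicVals] = decomposeMixedValues(mixedSizes);
  // staticVals uses ShapedType::kDynamic as the placeholder for each Value
  // collected in dynamicVals.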
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the cast's source is more static than the consuming op's operand.
Definition: TensorOps.cpp:4088
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:4092
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2709
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2710
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2697
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2698
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against instances of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:333
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.