TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/TypeUtilities.h"
31 #include "mlir/Support/LLVM.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallBitVector.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/Support/LogicalResult.h"
37 #include "llvm/Support/MathExtras.h"
38 #include <algorithm>
39 #include <optional>
40 
41 using namespace mlir;
42 using namespace mlir::tensor;
43 
44 using llvm::divideCeilSigned;
45 using llvm::divideFloorSigned;
46 using llvm::mod;
47 
48 /// Materialize a single constant operation from a given attribute value with
49 /// the desired resultant type.
50 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
51  Attribute value, Type type,
52  Location loc) {
53  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
54  return op;
55  if (complex::ConstantOp::isBuildableWith(value, type))
56  return builder.create<complex::ConstantOp>(loc, type,
57  llvm::cast<ArrayAttr>(value));
58  return nullptr;
59 }
60 
61 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
62  int64_t dim) {
63  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
64  if (tensorType.isDynamicDim(dim))
65  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
66 
67  return builder.getIndexAttr(tensorType.getDimSize(dim));
68 }
69 
70 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
71  Location loc, Value value) {
72  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
73  SmallVector<OpFoldResult> result;
74  for (int64_t i = 0; i < tensorType.getRank(); ++i)
75  result.push_back(getMixedSize(builder, loc, value, i));
76  return result;
77 }
78 
79 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
80  OpResult opResult) {
81  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
82  assert(tensorType && "expected tensor type");
83 
84  // If the op has a destination, it implements DestinationStyleOpInterface and
85  // we can query the destination operand from that interface.
86  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
87  if (destOp)
88  return destOp.getTiedOpOperand(opResult)->get();
89 
90  // Otherwise, create a new destination tensor with the same shape.
91  OpBuilder::InsertionGuard g(b);
92  b.setInsertionPoint(opResult.getDefiningOp());
93 
94  // Compute sizes.
95  SmallVector<OpFoldResult> mixedSizes;
96  if (!tensorType.hasStaticShape()) {
97  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
98  ReifiedRankedShapedTypeDims reifiedShapes;
99  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
100  return failure();
101  mixedSizes = reifiedShapes[opResult.getResultNumber()];
102  } else {
103  // Static shape: Take static sizes directly.
104  for (int64_t sz : tensorType.getShape())
105  mixedSizes.push_back(b.getIndexAttr(sz));
106  }
107 
108  // Create empty tensor.
109  Value emptyTensor =
110  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
111  return emptyTensor;
112 }
113 
114 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
115  Operation *op,
116  SmallVector<Value> &result) {
117  for (OpResult opResult : op->getResults()) {
118  if (llvm::isa<TensorType>(opResult.getType())) {
119  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
120  if (failed(destination))
121  return failure();
122  result.push_back(*destination);
123  }
124  }
125  return success();
126 }
127 
128 bool mlir::tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
129  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
130  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
131  return rtp1.getShape() == rtp2.getShape() &&
132  rtp1.getElementType() == rtp2.getElementType();
133  return false;
134  }
135  return tp1 == tp2; // default implementation
136 }
137 
138 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
139 /// rank-extending tensor.insert_slice op.
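/// For example (sizes chosen arbitrarily for illustration): with
/// `reducedShape = [4, 8]` and `mixedSizes = [4, 1, 8]`, the static unit size
/// at position 1 has no counterpart in the reduced shape, so the result is
/// `droppedDims = {1}`. With `reducedShape = [4, 1, 8]` and the same sizes,
/// no dimension is dropped.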
140 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
141  ArrayRef<OpFoldResult> mixedSizes) {
142  llvm::SmallBitVector droppedDims(mixedSizes.size());
143  int64_t shapePos = reducedShape.size() - 1;
144 
145  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
146  size_t idx = mixedSizes.size() - size.index() - 1;
147  // Rank-reduced dims must have a static unit dimension.
148  bool isStaticUnitSize =
149  isa<Attribute>(size.value()) &&
150  llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
151 
152  if (shapePos < 0) {
153  // There are no more dims in the reduced shape. All remaining sizes must
154  // be rank-reduced dims.
155  assert(isStaticUnitSize && "expected unit dim");
156  droppedDims.set(idx);
157  continue;
158  }
159 
160  // Dim is preserved if the size is not a static 1.
161  if (!isStaticUnitSize) {
162  --shapePos;
163  continue;
164  }
165 
166  // Dim is preserved if the reduced shape dim is also 1.
167  if (reducedShape[shapePos] == 1) {
168  --shapePos;
169  continue;
170  }
171 
172  // Otherwise: Dim is dropped.
173  droppedDims.set(idx);
174  }
175 
176  assert(shapePos < 0 && "dimension mismatch");
177  return droppedDims;
178 }
179 
180 /// Given a ranked tensor type and a range of values that defines its dynamic
181 /// dimension sizes, turn all dynamic sizes that have a constant value into
182 /// static dimension sizes.
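/// For example (values chosen arbitrarily for illustration): folding
/// `tensor<?x?xf32>` with dynamic sizes `[%c5, %n]`, where `%c5` is a constant
/// 5 and `%n` is unknown, yields `tensor<5x?xf32>` with
/// `foldedDynamicSizes = [%n]`.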
183 static RankedTensorType
184 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
185  SmallVector<Value> &foldedDynamicSizes) {
186  SmallVector<int64_t> staticShape(type.getShape());
187  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
188  "incorrect number of dynamic sizes");
189 
190  // Compute new static and dynamic sizes.
191  unsigned ctr = 0;
192  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
193  if (type.isDynamicDim(i)) {
194  Value dynamicSize = dynamicSizes[ctr++];
195  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
196  if (cst.has_value()) {
197  // Dynamic size must be non-negative.
198  if (cst.value() < 0) {
199  foldedDynamicSizes.push_back(dynamicSize);
200  continue;
201  }
202  staticShape[i] = *cst;
203  } else {
204  foldedDynamicSizes.push_back(dynamicSize);
205  }
206  }
207  }
208 
209  return RankedTensorType::get(staticShape, type.getElementType(),
210  type.getEncoding());
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // BitcastOp
215 //===----------------------------------------------------------------------===//
216 
217 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
218  if (inputs.size() != 1 || outputs.size() != 1)
219  return false;
220  Type a = inputs.front(), b = outputs.front();
221  auto aT = dyn_cast<TensorType>(a);
222  auto bT = dyn_cast<TensorType>(b);
223  if (!aT || !bT)
224  return false;
225 
226  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
227  return false;
228 
229  return succeeded(verifyCompatibleShape(aT, bT));
230 }
231 
232 namespace {
233 
234 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
235 /// operation.
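/// Example (types chosen arbitrarily for illustration):
/// ```mlir
/// %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xf32>
/// %2 = tensor.bitcast %1 : tensor<4xf32> to tensor<4xui32>
/// ```
/// folds into:
/// ```mlir
/// %2 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>
/// ```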
236 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
237  using OpRewritePattern<BitcastOp>::OpRewritePattern;
238 
239  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
240  PatternRewriter &rewriter) const final {
241  auto tensorBitcastOperand =
242  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
243  if (!tensorBitcastOperand)
244  return failure();
245 
246  auto resultType = cast<TensorType>(tensorBitcast.getType());
247  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
248  tensorBitcastOperand.getOperand());
249  return success();
250  }
251 };
252 
253 } // namespace
254 
255 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
256  MLIRContext *context) {
257  results.add<ChainedTensorBitcast>(context);
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // CastOp
262 //===----------------------------------------------------------------------===//
263 
264 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
265  setNameFn(getResult(), "cast");
266 }
267 
268 /// Returns true if `target` is a ranked tensor type that preserves static
269 /// information available in the `source` ranked tensor type.
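/// For example (shapes chosen arbitrarily for illustration):
///   preservesStaticInformation(tensor<?x8xf32>, tensor<4x8xf32>) == true
///   preservesStaticInformation(tensor<4x8xf32>, tensor<?x8xf32>) == false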
270 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
271  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
272  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
273 
274  // Requires RankedTensorType.
275  if (!sourceType || !targetType)
276  return false;
277 
278  // Requires same elemental type.
279  if (sourceType.getElementType() != targetType.getElementType())
280  return false;
281 
282  // Requires same rank.
283  if (sourceType.getRank() != targetType.getRank())
284  return false;
285 
286  // Requires same encoding.
287  if (sourceType.getEncoding() != targetType.getEncoding())
288  return false;
289 
290  // If cast is towards more static sizes along any dimension, don't fold.
291  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
292  if (!ShapedType::isDynamic(std::get<0>(t)) &&
293  ShapedType::isDynamic(std::get<1>(t)))
294  return false;
295  }
296 
297  return true;
298 }
299 
300 /// Determines whether tensor::CastOp casts to a more dynamic version of the
301 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
302 /// implement canonicalization patterns for ops in different dialects that may
303 /// consume the results of tensor.cast operations. Such foldable tensor.cast
304 /// operations are typically inserted as `slice` ops and are canonicalized
305 /// to preserve the type compatibility of their uses.
306 ///
307 /// Returns true when all conditions are met:
308 /// 1. source and result are ranked tensors with the same element type and rank.
309 /// 2. the source type has more static information than the result type.
310 ///
311 /// Example:
312 /// ```mlir
313 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
314 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
315 /// ```
316 ///
317 /// folds into:
318 ///
319 /// ```mlir
320 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
321 /// ```
322 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
323  if (!castOp)
324  return false;
325 
326  // Can fold if the source of cast has at least as much static information as
327  // its results.
328  return preservesStaticInformation(castOp.getType(),
329  castOp.getSource().getType());
330 }
331 
332 /// Determines whether the tensor::CastOp casts to a more static version of the
333 /// source tensor. This is useful to fold into a producing op and implement
334 /// canonicalization patterns with the `tensor.cast` op as the root, but
335 /// producer being from different dialects. Returns true when all conditions are
336 /// met:
337 /// 1. source and result are ranked tensors with the same element type and rank.
338 /// 2. the result type has more static information than the source.
339 ///
340 /// Example:
341 /// ```mlir
342 /// %1 = producer ... : tensor<?x?xf32>
343 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
344 /// ```
345 ///
346 /// can be canonicalized to:
347 ///
348 /// ```mlir
349 /// %2 = producer ... : tensor<8x16xf32>
350 /// ```
351 /// Not all ops can be canonicalized this way, but for those that can, this
352 /// method provides a check that the canonicalization is worthwhile.
353 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
354  if (!castOp)
355  return false;
356  return preservesStaticInformation(castOp.getSource().getType(),
357  castOp.getType());
358 }
359 
360 bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
361  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
362  if (llvm::isa<BlockArgument>(opOperand.get()))
363  return false;
364  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
365  return castOp && canFoldIntoConsumerOp(castOp);
366  });
367 }
368 
369 SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
370  DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
371  SmallVector<Value> newOperands;
372  newOperands.reserve(op->getNumOperands());
373 
374  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
375 
376  // Assumes that the result has dpsInits followed by nonDpsInits.
377  int64_t dpsInitIdx = 0;
378  for (OpOperand &opOperand : op->getOpOperands()) {
379  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
380  bool fold = canFoldIntoConsumerOp(tensorCastOp);
381  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
382  if (op.isDpsInit(&opOperand) &&
383  !llvm::isa<MemRefType>(newOperands.back().getType()))
384  newResTy[dpsInitIdx++] = newOperands.back().getType();
385  }
386  return newOperands;
387 }
388 
389 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
390 /// that can be folded.
391 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
392  bool folded = false;
393  for (OpOperand &operand : op->getOpOperands()) {
394  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
395  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
396  operand.set(castOp.getOperand());
397  folded = true;
398  }
399  }
400  return success(folded);
401 }
402 
403 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
404  if (inputs.size() != 1 || outputs.size() != 1)
405  return false;
406  Type a = inputs.front(), b = outputs.front();
407  auto aT = llvm::dyn_cast<TensorType>(a);
408  auto bT = llvm::dyn_cast<TensorType>(b);
409  if (!aT || !bT)
410  return false;
411 
412  if (aT.getElementType() != bT.getElementType())
413  return false;
414 
415  return succeeded(verifyCompatibleShape(aT, bT));
416 }
417 
418 /// Compute a TensorType that has the joined shape knowledge of the two
419 /// given TensorTypes. The element types need to match.
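/// For example (shapes chosen arbitrarily for illustration): joining
/// `tensor<?x8xf32>` with `tensor<4x?xf32>` yields `tensor<4x8xf32>`; joining
/// `tensor<4x8xf32>` with `tensor<5x8xf32>` yields a null type because the
/// static sizes conflict.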
420 static TensorType joinShapes(TensorType one, TensorType two) {
421  assert(one.getElementType() == two.getElementType());
422 
423  if (!one.hasRank())
424  return two;
425  if (!two.hasRank())
426  return one;
427 
428  int64_t rank = one.getRank();
429  if (rank != two.getRank())
430  return {};
431 
432  SmallVector<int64_t, 4> join;
433  join.reserve(rank);
434  for (int64_t i = 0; i < rank; ++i) {
435  if (one.isDynamicDim(i)) {
436  join.push_back(two.getDimSize(i));
437  continue;
438  }
439  if (two.isDynamicDim(i)) {
440  join.push_back(one.getDimSize(i));
441  continue;
442  }
443  if (one.getDimSize(i) != two.getDimSize(i))
444  return {};
445  join.push_back(one.getDimSize(i));
446  }
447  return RankedTensorType::get(join, one.getElementType());
448 }
449 
450 namespace {
451 
452 /// Replaces chains of two tensor.cast operations by a single tensor.cast
453 /// operation if doing so does not remove runtime constraints.
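/// Example (types chosen arbitrarily for illustration):
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x8xf32>
/// %2 = tensor.cast %1 : tensor<?x8xf32> to tensor<?x?xf32>
/// ```
/// folds into:
/// ```mlir
/// %2 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
/// ```
/// A cast from `tensor<?x8xf32>` through `tensor<4x8xf32>` back to
/// `tensor<?x8xf32>` is not folded, because dropping the intermediate cast
/// would remove a runtime check.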
454 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
455  using OpRewritePattern<CastOp>::OpRewritePattern;
456 
457  LogicalResult matchAndRewrite(CastOp tensorCast,
458  PatternRewriter &rewriter) const final {
459  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
460 
461  if (!tensorCastOperand)
462  return failure();
463 
464  auto sourceType =
465  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
466  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
467  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
468 
469  // We can remove the intermediate cast if joining all three produces the
470  // same result as just joining the source and result shapes.
471  auto firstJoin =
472  joinShapes(joinShapes(sourceType, intermediateType), resultType);
473 
474  // The join might not exist if the cast sequence would fail at runtime.
475  if (!firstJoin)
476  return failure();
477 
478  // The newJoin always exists if the above join exists; it might just contain
479  // less information. If so, we cannot drop the intermediate cast, as doing
480  // so would remove runtime checks.
481  auto newJoin = joinShapes(sourceType, resultType);
482  if (firstJoin != newJoin)
483  return failure();
484 
485  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
486  tensorCastOperand.getOperand());
487  return success();
488  }
489 };
490 
491 /// Fold tensor.cast into its tensor.extract_slice producer.
492 /// Example:
493 /// ```
494 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
495 /// tensor<128x512xf32> to tensor<?x512xf32>
496 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
497 /// ```
498 /// ->
499 /// ```
500 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
501 /// tensor<128x512xf32> to tensor<16x512xf32>
502 /// ```
503 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
504  using OpRewritePattern<CastOp>::OpRewritePattern;
505 
506  LogicalResult matchAndRewrite(CastOp tensorCast,
507  PatternRewriter &rewriter) const final {
508  auto extractOperand =
509  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
510 
511  // Cannot fold cast to unranked tensor.
512  auto rankedResultType =
513  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
514  if (!rankedResultType)
515  return failure();
516 
517  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
518  rankedResultType.getShape() ==
519  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
520  .getShape())
521  return failure();
522 
523  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
524  auto dimMask = computeRankReductionMask(
525  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
526  size_t dimIndex = 0;
527  for (size_t i = 0, e = sizes.size(); i < e; i++) {
528  if (dimMask && dimMask->count(i))
529  continue;
530  int64_t dim = rankedResultType.getShape()[dimIndex++];
531  if (ShapedType::isDynamic(dim))
532  continue;
533  sizes[i] = rewriter.getIndexAttr(dim);
534  }
535 
536  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
537  tensorCast, rankedResultType, extractOperand.getSource(),
538  extractOperand.getMixedOffsets(), sizes,
539  extractOperand.getMixedStrides());
540  return success();
541  }
542 };
543 
544 } // namespace
545 
546 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
547  MLIRContext *context) {
548  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // ConcatOp
553 //===----------------------------------------------------------------------===//
554 
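/// For example (shapes chosen arbitrarily for illustration), concatenating
/// tensor<2x4xf32> and tensor<3x4xf32> along dim 0 infers tensor<5x4xf32>;
/// if one input is tensor<?x4xf32> instead, the concatenated dimension becomes
/// dynamic and the inferred type is tensor<?x4xf32>.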
555 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
556  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
557  auto tensorTypes =
558  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
559  return llvm::cast<RankedTensorType>(type);
560  }));
561  int64_t concatRank = tensorTypes[0].getRank();
562 
563  // The concatenation dim must be in the range [0, rank).
564  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
565 
566  SmallVector<int64_t> sizes(concatRank);
567  for (int64_t i = 0, e = concatRank; i < e; ++i) {
568  if (i == dim)
569  continue;
570  SaturatedInteger size;
571  for (auto tensorType : tensorTypes)
572  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
573  sizes[i] = size.asInteger();
574  }
575  auto concatSize = SaturatedInteger::wrap(0);
576  for (auto tensorType : tensorTypes)
577  concatSize =
578  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
579  sizes[dim] = concatSize.asInteger();
580  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
581 }
582 
583 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
584  ValueRange inputs) {
585  FailureOr<RankedTensorType> resultType =
586  inferResultType(dim, inputs.getTypes());
587  assert(succeeded(resultType) && "failed to infer concatenation result type");
588  build(builder, result, *resultType, dim, inputs);
589 }
590 
591 LogicalResult ConcatOp::verify() {
592  if (getInputs().size() < 1)
593  return emitOpError("requires at least one input");
594 
595  SmallVector<RankedTensorType> inputTypes;
596  for (auto input : getInputs())
597  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
598 
599  RankedTensorType resultType = getResultType();
600  int64_t resultRank = getRank();
601  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
602  return type.getRank() != resultRank;
603  }))
604  return emitOpError("rank of concatenated inputs must match result rank");
605 
606  Type resultElementType = resultType.getElementType();
607  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
608  return type.getElementType() != resultElementType;
609  }))
610  return emitOpError("inputs and result element type must match");
611 
612  int64_t dim = getDim();
613  if (dim >= resultRank)
614  return emitOpError("concatenation dim must be less than the tensor rank");
615 
616  SmallVector<int64_t> sizes(resultRank);
617  for (int64_t i = 0, e = resultRank; i < e; ++i) {
618  if (i == dim)
619  continue;
620  SaturatedInteger size;
621  for (auto tensorType : inputTypes) {
622  FailureOr<SaturatedInteger> maybeSize =
623  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
624  if (failed(maybeSize))
625  return emitOpError("static concatenation size mismatch along ")
626  << "non-concatenated dimension " << i;
627  size = *maybeSize;
628  }
629  sizes[i] = size.asInteger();
630  }
631  auto concatSize = SaturatedInteger::wrap(0);
632  for (auto tensorType : inputTypes)
633  concatSize =
634  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
635  sizes[dim] = concatSize.asInteger();
636  auto inferredResultType =
637  RankedTensorType::get(sizes, inputTypes[0].getElementType());
638 
639  for (auto [inferredSize, actualSize] :
640  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
641  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
642  ShapedType::isDynamic(actualSize);
643  if (!hasDynamic && inferredSize != actualSize)
644  return emitOpError("result type ")
645  << resultType << " does not match inferred shape "
646  << inferredResultType << " static sizes";
647  }
648 
649  return success();
650 }
651 
652 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
653  size_t numInputs = getInputs().size();
654  uint64_t concatDim = getDim();
655 
656  SmallVector<SmallVector<OpFoldResult>> inputShapes;
657  inputShapes.reserve(numInputs);
658  SmallVector<OpFoldResult> concatOffsets;
659  concatOffsets.reserve(numInputs);
660  SmallVector<OpFoldResult> outputShape;
661 
662  AffineExpr addExpr =
663  builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
664  OpFoldResult zero = builder.getIndexAttr(0);
665  Location loc = getLoc();
666  for (auto [index, input] : llvm::enumerate(getInputs())) {
667  SmallVector<OpFoldResult> inputShape =
668  tensor::getMixedSizes(builder, input.getLoc(), input);
669  if (index == 0) {
670  outputShape = inputShape;
671  concatOffsets.push_back(zero);
672  } else {
673  concatOffsets.push_back(outputShape[concatDim]);
674  outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
675  builder, loc, addExpr,
676  {outputShape[concatDim], inputShape[concatDim]});
677  }
678  inputShapes.emplace_back(std::move(inputShape));
679  }
680 
681  Value replacement = builder.create<tensor::EmptyOp>(
682  loc, outputShape, getType().getElementType());
683 
684  int64_t rank = getType().getRank();
685  OpFoldResult one = builder.getIndexAttr(1);
686  SmallVector<OpFoldResult> strides(rank, one);
687  SmallVector<OpFoldResult> offsets(rank, zero);
688  for (auto [index, input] : llvm::enumerate(getInputs())) {
689  offsets[concatDim] = concatOffsets[index];
690  auto insertSlice = builder.create<tensor::InsertSliceOp>(
691  loc, input, replacement, offsets, inputShapes[index], strides);
692  replacement = insertSlice.getResult();
693  }
694  if (replacement.getType() != getType()) {
695  replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
696  }
697  return SmallVector<Value>{replacement};
698 }
699 
700 LogicalResult
701 ConcatOp::reifyResultShapes(OpBuilder &builder,
702  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
703  ValueRange inputs = getInputs();
704  int64_t dim = getDim();
705  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
706 
707  Value init = inputs[0];
708  int64_t rank = getType().getRank();
709 
710  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
711 
712  // Pre-populate the result sizes with as much static information as possible
713  // from the given result type, as well as the inferred result type; otherwise,
714  // use the dim sizes from the first input.
715  for (int64_t i = 0; i < rank; ++i) {
716  if (i == dim)
717  continue;
718  if (!getType().isDynamicDim(i)) {
719  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
720  } else if (!inferredResultType.isDynamicDim(i)) {
721  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
722  builder, getLoc(),
723  builder.getIndexAttr(inferredResultType.getDimSize(i)));
724  } else {
725  reifiedReturnShapes[0][i] =
726  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
727  }
728  }
729 
730  if (getType().isDynamicDim(dim)) {
731  // Take the sum of the input sizes along the concatenated dim.
732  AffineExpr sum = builder.getAffineDimExpr(0);
733  SmallVector<OpFoldResult> sizes = {
734  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
735  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
736  sum = sum + builder.getAffineDimExpr(idx + 1);
737  sizes.push_back(
738  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
739  }
740  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
741  builder, getLoc(),
742  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
743  } else {
744  // If the result shape is static along the concatenated dim, use the static
745  // shape.
746  reifiedReturnShapes[0][dim] =
747  builder.getIndexAttr(getType().getDimSize(dim));
748  }
749  return success();
750 }
751 
752 void ConcatOp::getAsmResultNames(
753  function_ref<void(Value, StringRef)> setNameFn) {
754  setNameFn(getResult(), "concat");
755 }
756 
757 OpFoldResult ConcatOp::fold(FoldAdaptor) {
758  ValueRange inputs = getInputs();
759  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
760  return inputs[0];
761  return {};
762 }
763 
764 namespace {
765 /// Fold a concat op with a single input to a cast.
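/// Example (types chosen arbitrarily for illustration):
/// ```mlir
/// %0 = tensor.concat dim(0) %arg0 : (tensor<4x8xf32>) -> tensor<?x8xf32>
/// ```
/// becomes
/// ```mlir
/// %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x8xf32>
/// ```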
766 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
767  using OpRewritePattern<ConcatOp>::OpRewritePattern;
768 
769  LogicalResult matchAndRewrite(ConcatOp concatOp,
770  PatternRewriter &rewriter) const override {
771  if (concatOp.getInputs().size() != 1)
772  return failure();
773  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
774  concatOp.getInputs()[0]);
775  return success();
776  }
777 };
778 
779 /// Propagate static shapes into the operands of a `tensor.concat`.
780 ///
781 /// `tensor.concat` requires every operand to match on all dimensions except the
782 /// concatenation dimension. If one operand is already static in those
783 /// dimensions, the other operands may safely be refined to that same static
784 /// shape.
785 ///
786 /// Example:
787 ///
788 /// ```mlir
789 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
790 /// tensor<?x12xi32>
791 /// ```
792 /// ->
793 /// ```mlir
794 /// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
795 /// %2 = tensor.concat dim(0) %0, %cast :
796 /// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
797 /// ```
798 struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
799  using OpRewritePattern<ConcatOp>::OpRewritePattern;
800 
801  LogicalResult matchAndRewrite(ConcatOp concatOp,
802  PatternRewriter &rewriter) const override {
803  int64_t dim = concatOp.getDim();
804  RankedTensorType inferredResultType =
805  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
806 
807  // Find operands for which a more static shape can be inferred.
808  LogicalResult matched = failure();
809  // Inferred operand shapes are identical in every dimension except the
810  // concatenation dimension.
811  SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
812  for (auto [operandIdx, operandType] :
813  llvm::enumerate(concatOp->getOperandTypes())) {
814  // Compute inferred type for operand.
815  inferredOperandShape[dim] =
816  cast<RankedTensorType>(operandType).getDimSize(dim);
817  auto inferredOperandType = RankedTensorType::get(
818  inferredOperandShape, inferredResultType.getElementType());
819 
820  // Check if inferred type is more static.
821  if (!preservesStaticInformation(inferredOperandType, operandType)) {
822  matched = success();
823 
824  // Use refined operand type and create cast from original operand.
825  auto castOp =
826  rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
827  concatOp.getOperand(operandIdx));
828  rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
829  concatOp->setOperand(operandIdx, castOp->getResult(0));
830  });
831  }
832  }
833 
834  return matched;
835  }
836 };
837 
838 /// Ensure `tensor.concat`'s result type is at least as static as can be inferred
839 /// from its operand types.
840 ///
841 /// Example:
842 /// ```mlir
843 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
844 /// tensor<?x?xi32>
845 /// ```
846 /// ->
847 /// ```mlir
848 /// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
849 /// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
850 /// tensor<?x?xi32>
851 /// ```
852 struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
853  using OpRewritePattern<ConcatOp>::OpRewritePattern;
854 
855  LogicalResult matchAndRewrite(ConcatOp concatOp,
856  PatternRewriter &rewriter) const override {
857  int64_t dim = concatOp.getDim();
858  RankedTensorType inferredResultType =
859  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
860 
861  // The result type should be at least as static as inferred result type.
862  if (preservesStaticInformation(inferredResultType,
863  concatOp.getResultType())) {
864  return failure();
865  }
866 
867  auto newConcatOp = rewriter.create<ConcatOp>(
868  concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
869  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
870  newConcatOp);
871 
872  return success();
873  }
874 };
875 } // namespace
876 
877 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
878  MLIRContext *context) {
879  results
880  .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
881  context);
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // DimOp
886 //===----------------------------------------------------------------------===//
887 
888 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
889  setNameFn(getResult(), "dim");
890 }
891 
892 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
893  int64_t index) {
894  auto loc = result.location;
895  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
896  build(builder, result, source, indexValue);
897 }
898 
899 std::optional<int64_t> DimOp::getConstantIndex() {
900  return getConstantIntValue(getIndex());
901 }
902 
903 Speculation::Speculatability DimOp::getSpeculatability() {
904  auto constantIndex = getConstantIndex();
905  if (!constantIndex)
906  return Speculation::NotSpeculatable;
907 
908  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
909  if (!rankedSourceType)
910  return Speculation::NotSpeculatable;
911 
912  if (rankedSourceType.getRank() <= constantIndex)
913  return Speculation::NotSpeculatable;
914 
915  return Speculation::Speculatable;
916 }
917 
918 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
919  SetIntLatticeFn setResultRange) {
920  setResultRange(getResult(),
921  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
922 }
923 
924 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
925  // All forms of folding require a known index.
926  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
927  if (!index)
928  return {};
929 
930  // Folding for unranked types (UnrankedTensorType) is not supported.
931  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
932  if (!tensorType)
933  return {};
934 
935  // Out-of-bounds indices produce undefined behavior but are still valid IR.
936  // Don't choke on them.
937  int64_t indexVal = index.getInt();
938  if (indexVal < 0 || indexVal >= tensorType.getRank())
939  return {};
940 
941  // Fold if the shape extent along the given index is known.
942  if (!tensorType.isDynamicDim(index.getInt())) {
943  Builder builder(getContext());
944  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
945  }
946 
947  Operation *definingOp = getSource().getDefiningOp();
948 
949  // Fold dim to the operand of tensor.generate.
950  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
951  auto resultType =
952  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
953  // The case where the type encodes the size of the dimension is handled
954  // above.
955  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
956 
957  // Find the operand of the fromElements that corresponds to this index.
958  auto dynExtents = fromElements.getDynamicExtents().begin();
959  for (auto dim : resultType.getShape().take_front(index.getInt()))
960  if (ShapedType::isDynamic(dim))
961  dynExtents++;
962 
963  return Value{*dynExtents};
964  }
965 
966  // The size at the given index is now known to be a dynamic size.
967  unsigned unsignedIndex = index.getValue().getZExtValue();
968 
969  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
970  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
971  // `resolve-shaped-type-result-dims` pass.
972  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
973  sliceOp.isDynamicSize(unsignedIndex)) {
974  return {sliceOp.getDynamicSize(unsignedIndex)};
975  }
976  }
977 
978  // dim(cast) -> dim
979  if (succeeded(foldTensorCast(*this)))
980  return getResult();
981 
982  return {};
983 }
984 
985 namespace {
986 /// Fold dim of a cast into the dim of the source of the tensor cast.
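/// Example (types chosen arbitrarily for illustration):
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
/// %d = tensor.dim %1, %c0 : tensor<?x?xf32>
/// ```
/// folds into:
/// ```mlir
/// %d = tensor.dim %0, %c0 : tensor<4x?xf32>
/// ```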
987 struct DimOfCastOp : public OpRewritePattern<DimOp> {
988  using OpRewritePattern<DimOp>::OpRewritePattern;
989 
990  LogicalResult matchAndRewrite(DimOp dimOp,
991  PatternRewriter &rewriter) const override {
992  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
993  if (!castOp)
994  return failure();
995  Value newSource = castOp.getOperand();
996  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
997  return success();
998  }
999 };
1000 
1001 /// Fold dim of a destination passing style op into the dim of the corresponding
1002 /// init.
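/// For example, given a schematic destination-style op
/// `%r = some.dps_op ... outs(%init : tensor<?x?xf32>)` (op name hypothetical),
/// `tensor.dim %r, %c0` is rewritten in place to `tensor.dim %init, %c0`.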
1003 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1004  using OpRewritePattern<DimOp>::OpRewritePattern;
1005 
1006  LogicalResult matchAndRewrite(DimOp dimOp,
1007  PatternRewriter &rewriter) const override {
1008  auto source = dimOp.getSource();
1009  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1010  if (!destOp)
1011  return failure();
1012 
1013  auto resultIndex = cast<OpResult>(source).getResultNumber();
1014  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1015 
1016  rewriter.modifyOpInPlace(
1017  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1018  return success();
1019  }
1020 };
1021 
1022 /// Fold dim of a tensor reshape operation into an extract from the reshape's
1023 /// shape operand.
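/// Example (types chosen arbitrarily for illustration):
/// ```mlir
/// %r = tensor.reshape %src(%shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// %d = tensor.dim %r, %idx : tensor<?x?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = tensor.extract %shape[%idx] : tensor<2xindex>
/// ```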
1024 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1025  using OpRewritePattern<DimOp>::OpRewritePattern;
1026 
1027  LogicalResult matchAndRewrite(DimOp dim,
1028  PatternRewriter &rewriter) const override {
1029  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1030 
1031  if (!reshape)
1032  return failure();
1033 
1034  // Since tensors are immutable, we don't need to worry about where to place
1035  // the extract call.
1036  rewriter.setInsertionPointAfter(dim);
1037  Location loc = dim.getLoc();
1038  Value extract =
1039  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
1040  if (extract.getType() != dim.getType())
1041  extract =
1042  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
1043  rewriter.replaceOp(dim, extract);
1044  return success();
1045  }
1046 };
1047 } // namespace
1048 
1049 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1050  MLIRContext *context) {
1051  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1052 }
1053 
1054 //===----------------------------------------------------------------------===//
1055 // EmptyOp
1056 //===----------------------------------------------------------------------===//
1057 
1058 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1059  ArrayRef<int64_t> staticShape, Type elementType,
1060  Attribute encoding) {
1061  assert(all_of(staticShape,
1062  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
1063  "expected only static sizes");
1064  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1065 }
1066 
1067 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1068  ArrayRef<int64_t> staticShape, Type elementType,
1069  ValueRange dynamicSizes, Attribute encoding) {
1070  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1071  build(builder, result, tensorType, dynamicSizes);
1072 }
1073 
1074 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1075  ArrayRef<OpFoldResult> sizes, Type elementType,
1076  Attribute encoding) {
1077  SmallVector<int64_t> staticShape;
1078  SmallVector<Value> dynamicSizes;
1079  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1080  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1081 }
1082 
1083 LogicalResult EmptyOp::verify() {
1084  if (getType().getNumDynamicDims() != getDynamicSizes().size())
1085  return emitOpError("incorrect number of dynamic sizes, has ")
1086  << getDynamicSizes().size() << ", expected "
1087  << getType().getNumDynamicDims();
1088  return success();
1089 }
1090 
1091 LogicalResult
1092 EmptyOp::reifyResultShapes(OpBuilder &builder,
1093  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1094  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1095  unsigned ctr = 0;
1096  for (int64_t i = 0; i < getType().getRank(); ++i) {
1097  if (getType().isDynamicDim(i)) {
1098  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1099  } else {
1100  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1101  }
1102  }
1103  return success();
1104 }
1105 
1106 Value EmptyOp::getDynamicSize(unsigned idx) {
1107  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1108  unsigned ctr = 0;
1109  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1110  if (getType().isDynamicDim(i))
1111  ++ctr;
1112  return getDynamicSizes()[ctr];
1113 }
1114 
1115 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1116  SmallVector<OpFoldResult> result;
1117  unsigned ctr = 0;
1118  OpBuilder b(getContext());
1119  for (int64_t i = 0; i < getType().getRank(); ++i) {
1120  if (getType().isDynamicDim(i)) {
1121  result.push_back(getDynamicSizes()[ctr++]);
1122  } else {
1123  result.push_back(b.getIndexAttr(getType().getShape()[i]));
1124  }
1125  }
1126  return result;
1127 }
1128 
1129 namespace {
1130 /// Change the type of the result of a `tensor.empty` by making the result
1131 /// type statically sized along dimensions that in the original operation were
1132 /// defined as dynamic, but the size was defined using a `constant` op. For
1133 /// example
1134 ///
1135 /// %c5 = arith.constant 5: index
1136 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1137 ///
1138 /// to
1139 ///
1140 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1141 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1142  using OpRewritePattern<EmptyOp>::OpRewritePattern;
1143 
1144  LogicalResult matchAndRewrite(EmptyOp op,
1145  PatternRewriter &rewriter) const override {
1146  SmallVector<Value> foldedDynamicSizes;
1147  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1148  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1149 
1150  // Stop here if no dynamic size was promoted to static.
1151  if (foldedTensorType == op.getType())
1152  return failure();
1153 
1154  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
1155  foldedDynamicSizes);
1156  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1157  return success();
1158  }
1159 };
1160 
1161 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1162  using OpRewritePattern<DimOp>::OpRewritePattern;
1163 
1164  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1165  PatternRewriter &rewriter) const override {
1166  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1167  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1168  if (!emptyTensorOp || !maybeConstantIndex)
1169  return failure();
1170  auto emptyTensorType = emptyTensorOp.getType();
1171  if (*maybeConstantIndex < 0 ||
1172  *maybeConstantIndex >= emptyTensorType.getRank() ||
1173  !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1174  return failure();
1175  rewriter.replaceOp(dimOp,
1176  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1177  return success();
1178  }
1179 };
1180 
1181 /// Canonicalize
1182 ///
1183 /// ```mlir
1184 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1185 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1186 /// ```
1187 ///
1188 /// into
1189 ///
1190 /// ```mlir
1191 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1192 /// ```
1193 ///
1194 /// This assumes the input program is correct in terms of its shape. So it is
1195 /// safe to assume that `%d0` is in fact 4.
1196 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1197  using OpRewritePattern<CastOp>::OpRewritePattern;
1198 
1199  LogicalResult matchAndRewrite(CastOp castOp,
1200  PatternRewriter &rewriter) const override {
1201  if (!canFoldIntoProducerOp(castOp))
1202  return failure();
1203  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1204  if (!producer)
1205  return failure();
1206 
1207  auto resultType =
1208  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1209  ArrayRef<int64_t> resultShape = resultType.getShape();
1210  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1211  SmallVector<OpFoldResult> newMixedSizes;
1212  newMixedSizes.reserve(currMixedSizes.size());
1213  assert(resultShape.size() == currMixedSizes.size() &&
1214  "mismatch in result shape and sizes of empty op");
1215  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1216  int64_t newDim = std::get<0>(it);
1217  OpFoldResult currDim = std::get<1>(it);
1218  // Case 1: The empty tensor dim is static. Check that the tensor cast
1219  // result dim matches.
1220  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1221  if (ShapedType::isDynamic(newDim) ||
1222  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1223  // Something is off, the cast result shape cannot be more dynamic
1224  // than the empty tensor result shape (enforced by
1225  // `canFoldIntoProducerOp`). Abort for now.
1226  return rewriter.notifyMatchFailure(
1227  producer, "mismatch in static value of shape of empty tensor "
1228  "result and cast result");
1229  }
1230  newMixedSizes.push_back(attr);
1231  continue;
1232  }
1233 
1234  // Case 2 : The tensor cast shape is static, but empty tensor result
1235  // shape is dynamic.
1236  if (!ShapedType::isDynamic(newDim)) {
1237  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1238  continue;
1239  }
1240 
1241  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1242  // shape is dynamic. Use the dynamic value from the empty tensor op.
1243  newMixedSizes.push_back(currDim);
1244  }
1245 
1246  // TODO: Do not drop tensor encoding.
1247  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1248  resultType.getElementType());
1249  return success();
1250  }
1251 };
1252 
1253 } // namespace
1254 
1255 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1256  MLIRContext *context) {
1257  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1258  ReplaceEmptyTensorStaticShapeDims>(context);
1259 }
1260 
1261 //===----------------------------------------------------------------------===//
1262 // ExtractOp
1263 //===----------------------------------------------------------------------===//
1264 
1265 namespace {
1266 
1267 /// Canonicalizes the pattern of the form
1268 ///
1269 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1270 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1271 ///
1272 /// to
1273 ///
1274 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1275 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1276  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1277 
1278  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1279  PatternRewriter &rewriter) const final {
1280  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1281  if (!tensorCast)
1282  return failure();
1283  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1284  return failure();
1285  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1286  extract, tensorCast.getSource(), extract.getIndices());
1287  return success();
1288  }
1289 };
1290 
1291 } // namespace
1292 
1293 void ExtractOp::getAsmResultNames(
1294  function_ref<void(Value, StringRef)> setNameFn) {
1295  setNameFn(getResult(), "extracted");
1296 }
1297 
1298 LogicalResult ExtractOp::verify() {
1299  // Verify the # indices match if we have a ranked type.
1300  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1301  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1302  return emitOpError("incorrect number of indices for extract_element");
1303  return success();
1304 }
1305 
1306 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1307  if (Attribute tensor = adaptor.getTensor()) {
1308  // If this is a splat elements attribute, simply return the value.
1309  // All of the elements of a splat attribute are the same.
1310  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1311  return splatTensor.getSplatValue<Attribute>();
1312 
1313  // If this is a dense resource elements attribute, return.
1314  if (isa<DenseResourceElementsAttr>(tensor))
1315  return {};
1316  }
1317 
1318  // Collect the constant indices into the tensor.
1319  SmallVector<uint64_t, 8> indices;
1320  for (Attribute indice : adaptor.getIndices()) {
1321  if (!indice || !llvm::isa<IntegerAttr>(indice))
1322  return {};
1323  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1324  }
1325 
1326  // Fold extract(from_elements(...)).
1327  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1328  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1329  auto rank = tensorType.getRank();
1330  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1331  "rank mismatch");
1332  int flatIndex = 0;
1333  int stride = 1;
1334  for (int i = rank - 1; i >= 0; --i) {
1335  flatIndex += indices[i] * stride;
1336  stride *= tensorType.getDimSize(i);
1337  }
1338  // Prevent out of bounds accesses. This can happen in invalid code that
1339  // will never execute.
1340  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1341  flatIndex < 0)
1342  return {};
1343  return fromElementsOp.getElements()[flatIndex];
1344  }
1345 
1346  // If this is an elements attribute, query the value at the given indices.
1347  if (Attribute tensor = adaptor.getTensor()) {
1348  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1349  if (elementsAttr && elementsAttr.isValidIndex(indices))
1350  return elementsAttr.getValues<Attribute>()[indices];
1351  }
1352 
1353  return {};
1354 }
1355 
1356 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1357  MLIRContext *context) {
1358  results.add<ExtractFromTensorCast>(context);
1359 }
1360 
1361 //===----------------------------------------------------------------------===//
1362 // FromElementsOp
1363 //===----------------------------------------------------------------------===//
1364 
1365 void FromElementsOp::getAsmResultNames(
1366  function_ref<void(Value, StringRef)> setNameFn) {
1367  setNameFn(getResult(), "from_elements");
1368 }
1369 
1370 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1371  ValueRange elements) {
1372  assert(!elements.empty() && "expected at least one element");
1373  Type resultType = RankedTensorType::get(
1374  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1375  build(builder, result, resultType, elements);
1376 }
1377 
1378 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1379  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1380  return DenseElementsAttr::get(getType(), adaptor.getElements());
1381  return {};
1382 }
1383 
1384 namespace {
1385 
1386 // Pushes the index_casts that occur before extractions to after the extract.
1387 // This minimizes type conversion in some cases and enables the extract
1388 // canonicalizer. This changes:
1389 //
1390 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1391 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1392 //
1393 // to the following:
1394 //
1395 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1396 // %cast = arith.index_cast %extract : i32 to index
1397 //
1400 // Consider expanding this to a template and handle all tensor cast
1401 // operations.
1402 struct ExtractElementFromIndexCast
1403  : public OpRewritePattern<tensor::ExtractOp> {
1404  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1405 
1406  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1407  PatternRewriter &rewriter) const final {
1408  Location loc = extract.getLoc();
1409  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1410  if (!indexCast)
1411  return failure();
1412 
1413  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1414 
1415  auto newExtract = rewriter.create<tensor::ExtractOp>(
1416  loc, elementTy, indexCast.getIn(), extract.getIndices());
1417 
1418  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1419  newExtract);
1420 
1421  return success();
1422  }
1423 };
1424 
1425 } // namespace
1426 
1427 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1428  MLIRContext *context) {
1429  results.add<ExtractElementFromIndexCast>(context);
1430 }
1431 
1432 //===----------------------------------------------------------------------===//
1433 // GatherOp
1434 //===----------------------------------------------------------------------===//
1435 
1436 void GatherOp::getAsmResultNames(
1437  function_ref<void(Value, StringRef)> setNameFn) {
1438  setNameFn(getResult(), "gather");
1439 }
1440 
1441 /// Return the inferred result type for a gatherOp where:
1442 /// - sourceType is the type of the source tensor gathered from
1443 /// - indicesType is the type of the indices used to gather
1444 /// - gatherDims are the dims along which the gather occurs.
1445 /// Return a full-rank or rank-reduced variant of the type depending on
1446 /// the value of rankReduced.
1447 ///
1448 /// The leading dimensions of the index tensor give the result tensor its
1449 /// leading dimensions.
1450 /// The trailing dimensions of the result tensor are obtained from the source
1451 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1452 /// rankReduced is false), or skipping them (otherwise).
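/// For example (shapes chosen arbitrarily for illustration), with source
/// tensor<4x5x6xf32>, indices tensor<2x3x1xindex> and gather_dims = [1], the
/// inferred type is tensor<2x3x4x1x6xf32> when rankReduced is false and
/// tensor<2x3x4x6xf32> when rankReduced is true.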
1453 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1454  RankedTensorType indicesType,
1455  ArrayRef<int64_t> gatherDims,
1456  bool rankReduced) {
1457  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1458  resultShape.reserve(resultShape.size() + sourceType.getRank());
1459  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1460  if (llvm::binary_search(gatherDims, idx)) {
1461  if (!rankReduced)
1462  resultShape.push_back(1);
1463  continue;
1464  }
1465  resultShape.push_back(sourceType.getDimSize(idx));
1466  }
1467  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1468 }
1469 
1470 static LogicalResult
1471 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1472  ArrayRef<int64_t> indices, int64_t rank,
1473  StringRef gatherOrScatter, StringRef sourceOrDest) {
1474  if (dims.empty())
1475  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1476 
1477  int64_t numGatherDims = dims.size();
1478  if (numGatherDims > rank)
1479  return op->emitOpError(gatherOrScatter)
1480  << "_dims overflow " << sourceOrDest << " rank";
1481  if (indices.empty() || indices.back() != numGatherDims)
1482  return op->emitOpError(gatherOrScatter)
1483  << "_dims length must match the size of the last dimension of indices";
1484  for (int64_t val : dims) {
1485  if (val < 0)
1486  return op->emitOpError(gatherOrScatter)
1487  << "_dims value must be non-negative";
1488  if (val >= rank)
1489  return op->emitOpError(gatherOrScatter)
1490  << "_dims value must be smaller than " << sourceOrDest << " rank";
1491  }
1492  for (int64_t i = 1; i < numGatherDims; ++i) {
1493  if (dims[i - 1] >= dims[i])
1494  return op->emitOpError(gatherOrScatter)
1495  << "_dims values must be strictly increasing";
1496  }
1497  return success();
1498 }
1499 
1500 LogicalResult GatherOp::verify() {
1501  int64_t sourceRank = getSourceType().getRank();
1502  ArrayRef<int64_t> gatherDims = getGatherDims();
1503  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1504  getIndicesType().getShape(), sourceRank,
1505  "gather", "source")))
1506  return failure();
1507 
1508  RankedTensorType expectedResultType = GatherOp::inferResultType(
1509  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1510  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1511  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1512  if (getResultType() != expectedResultType &&
1513  getResultType() != expectedRankReducedResultType) {
1514  return emitOpError("result type "
1515  "mismatch: "
1516  "expected ")
1517  << expectedResultType << " or its rank-reduced variant "
1518  << expectedRankReducedResultType << " (got: " << getResultType()
1519  << ")";
1520  }
1521 
1522  return success();
1523 }
1524 
1525 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1526  if (OpFoldResult reshapedSource = reshapeConstantSource(
1527  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1528  getResult().getType()))
1529  return reshapedSource;
1530  return {};
1531 }
1532 
1533 //===----------------------------------------------------------------------===//
1534 // InsertOp
1535 //===----------------------------------------------------------------------===//
1536 
1537 void InsertOp::getAsmResultNames(
1538  function_ref<void(Value, StringRef)> setNameFn) {
1539  setNameFn(getResult(), "inserted");
1540 }
1541 
1542 LogicalResult InsertOp::verify() {
1543  // Verify the # indices match if we have a ranked type.
1544  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1545  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1546  return emitOpError("incorrect number of indices");
1547  return success();
1548 }
1549 
1550 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1551  Attribute scalar = adaptor.getScalar();
1552  Attribute dest = adaptor.getDest();
1553  if (scalar && dest)
1554  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1555  if (scalar == splatDest.getSplatValue<Attribute>())
1556  return dest;
1557  return {};
1558 }
1559 
1560 //===----------------------------------------------------------------------===//
1561 // GenerateOp
1562 //===----------------------------------------------------------------------===//
1563 
1564 void GenerateOp::getAsmResultNames(
1565  function_ref<void(Value, StringRef)> setNameFn) {
1566  setNameFn(getResult(), "generated");
1567 }
1568 
1569 LogicalResult GenerateOp::reifyResultShapes(
1570  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1571  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1572  int idx = 0;
1573  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1574  if (getType().isDynamicDim(dim)) {
1575  reifiedReturnShapes[0][dim] = getOperand(idx++);
1576  } else {
1577  reifiedReturnShapes[0][dim] =
1578  builder.getIndexAttr(getType().getDimSize(dim));
1579  }
1580  }
1581  return success();
1582 }
1583 
1584 LogicalResult GenerateOp::verify() {
1585  // Ensure that the tensor type has as many dynamic dimensions as are
1586  // specified by the operands.
1587  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1588  if (getNumOperands() != resultType.getNumDynamicDims())
1589  return emitError("must have as many index operands as dynamic extents "
1590  "in the result type");
1591  return success();
1592 }
1593 
1594 LogicalResult GenerateOp::verifyRegions() {
1595  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1596  // Ensure that region arguments span the index space.
1597  if (!llvm::all_of(getBody().getArgumentTypes(),
1598  [](Type ty) { return ty.isIndex(); }))
1599  return emitError("all body arguments must be index");
1600  if (getBody().getNumArguments() != resultTy.getRank())
1601  return emitError("must have one body argument per input dimension");
1602 
1603  // Ensure that the region yields an element of the right type.
1604  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1605 
1606  if (yieldOp.getValue().getType() != resultTy.getElementType())
1607  return emitOpError(
1608  "body must be terminated with a `yield` operation of the tensor "
1609  "element type");
1610 
1611  return success();
1612 }
1613 
1614 void GenerateOp::build(
1615  OpBuilder &b, OperationState &result, Type resultTy,
1616  ValueRange dynamicExtents,
1617  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1618  build(b, result, resultTy, dynamicExtents);
1619 
1620  // Build and populate body.
1621  OpBuilder::InsertionGuard guard(b);
1622  Region *bodyRegion = result.regions.front().get();
1623  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1624  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1625  SmallVector<Location, 2> argumentLocs(rank, result.location);
1626  Block *bodyBlock =
1627  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1628  bodyBuilder(b, result.location, bodyBlock->getArguments());
1629 }
1630 
1631 namespace {
1632 
1633 /// Canonicalizes tensor.generate operations whose dynamic extent operands are
1634 /// constants by expressing those extents in the result type instead. A type
1635 /// cast is inserted so that the resulting IR remains well-typed.
1636 ///
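/// For illustration, a hedged sketch of the rewrite (the shapes and body are
/// made up for the example):
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %g = tensor.generate %c8 {
/// ^bb0(%i: index):
///   %v = arith.index_cast %i : index to i32
///   tensor.yield %v : i32
/// } : tensor<?xi32>
/// ```
///
/// becomes
///
/// ```mlir
/// %static = tensor.generate {
/// ^bb0(%i: index):
///   %v = arith.index_cast %i : index to i32
///   tensor.yield %v : i32
/// } : tensor<8xi32>
/// %g = tensor.cast %static : tensor<8xi32> to tensor<?xi32>
/// ```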
1637 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1638  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1639 
1640  LogicalResult matchAndRewrite(GenerateOp generateOp,
1641  PatternRewriter &rewriter) const final {
1642  SmallVector<Value> foldedDynamicSizes;
1643  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1644  generateOp.getType(), generateOp.getDynamicExtents(),
1645  foldedDynamicSizes);
1646 
1647  // Stop here if no dynamic size was promoted to static.
1648  if (foldedTensorType == generateOp.getType())
1649  return failure();
1650 
1651  auto loc = generateOp.getLoc();
1652  auto newOp =
1653  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1654  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1655  newOp.getBody().begin());
1656  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1657  generateOp.getType(), newOp);
1658  return success();
1659  }
1660 };
1661 
1662 /// Canonicalizes the pattern of the form
1663 ///
1664 /// %tensor = tensor.generate %x {
1665 /// ^bb0(%arg0: index):
1666 /// <computation>
1667 /// yield %1 : index
1668 /// } : tensor<?xindex>
1669 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1670 ///
1671 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1672 /// tensor.generate operation has no side-effects.
1673 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1674  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1675 
1676  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1677  PatternRewriter &rewriter) const final {
1678  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1679  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1680  return failure();
1681 
1682  IRMapping mapping;
1683  Block *body = &tensorFromElements.getBody().front();
1684  mapping.map(body->getArguments(), extract.getIndices());
1685  for (auto &op : body->without_terminator())
1686  rewriter.clone(op, mapping);
1687 
1688  auto yield = cast<YieldOp>(body->getTerminator());
1689 
1690  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1691  return success();
1692  }
1693 };
1694 
1695 } // namespace
1696 
1697 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1698  MLIRContext *context) {
1699  // TODO: Move extract pattern to tensor::ExtractOp.
1700  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1701 }
1702 
1703 //===----------------------------------------------------------------------===//
1704 // RankOp
1705 //===----------------------------------------------------------------------===//
1706 
1707 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1708  setNameFn(getResult(), "rank");
1709 }
1710 
1711 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1712  // Constant fold rank when the rank of the operand is known.
1713  auto type = getOperand().getType();
1714  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1715  if (shapedType && shapedType.hasRank())
1716  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1717  return IntegerAttr();
1718 }
1719 
1720 //===----------------------------------------------------------------------===//
1721 // ReshapeOp
1722 //===----------------------------------------------------------------------===//
1723 
1724 void ReshapeOp::getAsmResultNames(
1725  function_ref<void(Value, StringRef)> setNameFn) {
1726  setNameFn(getResult(), "reshape");
1727 }
1728 
1729 static int64_t getNumElements(ShapedType type) {
1730  int64_t numElements = 1;
1731  for (auto dim : type.getShape())
1732  numElements *= dim;
1733  return numElements;
1734 }
1735 
1736 LogicalResult ReshapeOp::verify() {
1737  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1738  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1739 
1740  if (operandType.getElementType() != resultType.getElementType())
1741  return emitOpError("element types of source and destination tensor "
1742  "types should be the same");
1743 
1744  int64_t shapeSize =
1745  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1746  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1747  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1748 
1749  if (resultRankedType) {
1750  if (operandRankedType && resultRankedType.hasStaticShape() &&
1751  operandRankedType.hasStaticShape()) {
1752  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1753  return emitOpError("source and destination tensor should have the "
1754  "same number of elements");
1755  }
1756  if (ShapedType::isDynamic(shapeSize))
1757  return emitOpError("cannot use shape operand with dynamic length to "
1758  "reshape to statically-ranked tensor type");
1759  if (shapeSize != resultRankedType.getRank())
1760  return emitOpError(
1761  "length of shape operand differs from the result's tensor rank");
1762  }
1763  return success();
1764 }
1765 
1766 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1767  if (OpFoldResult reshapedSource = reshapeConstantSource(
1768  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1769  getResult().getType()))
1770  return reshapedSource;
1771 
1772  // If the producer of the 'source' operand is another 'tensor.reshape' op,
1773  // reshape the producer's input directly. This may leave the producer dead,
1774  // to be removed by DCE.
1775  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1776  getSourceMutable().assign(reshapeOpProducer.getSource());
1777  return getResult();
1778  }
1779 
1780  auto source = getSource();
1781  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1782  auto resultTy = dyn_cast<RankedTensorType>(getType());
1783  if (!sourceTy || !resultTy || sourceTy != resultTy)
1784  return {};
1785 
1786  // If the source and result are both 1D tensors and have the same type, the
1787  // reshape has no effect, even if the tensor is dynamically shaped.
1788  if (sourceTy.getRank() == 1)
1789  return source;
1790 
1791  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1792  auto elements = fromElements.getElements();
1793  bool dynamicNoop =
1794  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1795  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1796  auto element = elements[id];
1797 
1798  if (auto cst = getConstantIntValue(element)) {
1799  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1800  continue;
1801  }
1802 
1803  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1804  dynamicNoop &= dimOp.getSource() == source;
1805 
1806  auto cst = getConstantIntValue(dimOp.getIndex());
1807  dynamicNoop &=
1808  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1809  continue;
1810  }
1811 
1812  dynamicNoop = false;
1813  break;
1814  }
1815 
1816  if (dynamicNoop)
1817  return source;
1818  }
1819 
1820  return {};
1821 }
1822 
1823 //===----------------------------------------------------------------------===//
1824 // Reassociative reshape ops
1825 //===----------------------------------------------------------------------===//
1826 
1827 void CollapseShapeOp::getAsmResultNames(
1828  function_ref<void(Value, StringRef)> setNameFn) {
1829  setNameFn(getResult(), "collapsed");
1830 }
1831 
1832 void ExpandShapeOp::getAsmResultNames(
1833  function_ref<void(Value, StringRef)> setNameFn) {
1834  setNameFn(getResult(), "expanded");
1835 }
1836 
1837 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1838  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1839  "invalid resultDim");
1840  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1841  if (llvm::is_contained(it.value(), resultDim))
1842  return it.index();
1843  llvm_unreachable("could not find reassociation group");
1844 }
1845 
1846 FailureOr<SmallVector<OpFoldResult>>
1847 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1848  RankedTensorType expandedType,
1849  ArrayRef<ReassociationIndices> reassociation,
1850  ArrayRef<OpFoldResult> inputShape) {
1851  std::optional<SmallVector<OpFoldResult>> outputShape =
1852  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1853  inputShape);
1854  if (!outputShape)
1855  return failure();
1856  return *outputShape;
1857 }
1858 
1859 SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1860  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1861 }
1862 
1863 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1864  Type resultType, Value src,
1865  ArrayRef<ReassociationIndices> reassociation,
1866  ArrayRef<OpFoldResult> outputShape) {
1867  auto [staticOutputShape, dynamicOutputShape] =
1868  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1869  build(builder, result, cast<RankedTensorType>(resultType), src,
1870  getReassociationIndicesAttribute(builder, reassociation),
1871  dynamicOutputShape, staticOutputShape);
1872 }
1873 
1874 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1875  Type resultType, Value src,
1876  ArrayRef<ReassociationIndices> reassociation) {
1877  SmallVector<OpFoldResult> inputShape =
1878  getMixedSizes(builder, result.location, src);
1879  auto tensorResultTy = cast<RankedTensorType>(resultType);
1880  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1881  builder, result.location, tensorResultTy, reassociation, inputShape);
1882  SmallVector<OpFoldResult> outputShapeOrEmpty;
1883  if (succeeded(outputShape)) {
1884  outputShapeOrEmpty = *outputShape;
1885  }
1886  build(builder, result, tensorResultTy, src, reassociation,
1887  outputShapeOrEmpty);
1888 }
1889 
1890 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1891  return getSymbolLessAffineMaps(getReassociationExprs());
1892 }
1893 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1894  return convertReassociationIndicesToExprs(getContext(),
1895  getReassociationIndices());
1896 }
1897 
1898 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1899  return getSymbolLessAffineMaps(getReassociationExprs());
1900 }
1901 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1902  return convertReassociationIndicesToExprs(getContext(),
1903  getReassociationIndices());
1904 }
1905 
1906 RankedTensorType CollapseShapeOp::inferCollapsedType(
1907  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1908  return inferCollapsedType(
1909  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1910  type.getContext(), reassociation)));
1911 }
1912 
1913 /// Compute the RankedTensorType obtained by applying `reassociation` to
1914 /// `type`.
1915 RankedTensorType
1916 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1917  ArrayRef<AffineMap> reassociation) {
1918  auto shape = type.getShape();
1919  SmallVector<int64_t, 4> newShape;
1920  newShape.reserve(reassociation.size());
1921 
1922  // Use the fact that reassociation is valid to simplify the logic: only use
1923  // each map's rank.
1924  assert(isReassociationValid(reassociation) && "invalid reassociation");
1925  unsigned currentDim = 0;
1926  for (AffineMap m : reassociation) {
1927  unsigned dim = m.getNumResults();
1928  auto band = shape.slice(currentDim, dim);
1929  int64_t size = 1;
1930  if (llvm::is_contained(band, ShapedType::kDynamic))
1931  size = ShapedType::kDynamic;
1932  else
1933  for (unsigned d = 0; d < dim; ++d)
1934  size *= shape[currentDim + d];
1935  newShape.push_back(size);
1936  currentDim += dim;
1937  }
1938 
1939  return RankedTensorType::get(newShape, type.getElementType());
1940 }
1941 
1942 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1943  ArrayRef<ReassociationIndices> reassociation,
1944  ArrayRef<NamedAttribute> attrs) {
1945  auto resultType = inferCollapsedType(
1946  llvm::cast<RankedTensorType>(src.getType()),
1947  getSymbolLessAffineMaps(
1948  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1949  result.addAttribute(getReassociationAttrStrName(),
1950  getReassociationIndicesAttribute(b, reassociation));
1951  build(b, result, resultType, src, attrs);
1952 }
1953 
1954 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1955  TensorReshapeOp, ExpandShapeOp>::value>
1956 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1957  RankedTensorType expandedType,
1958  RankedTensorType collapsedType) {
1959  if (failed(
1960  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1961  return failure();
1962 
1963  auto maps = op.getReassociationMaps();
1964  RankedTensorType expectedType =
1965  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1966  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1967  return op.emitOpError("expected collapsed type to be ")
1968  << expectedType << ", but got " << collapsedType;
1969  return success();
1970 }
1971 
1972 LogicalResult ExpandShapeOp::verify() {
1973  auto srcType = getSrcType();
1974  auto resultType = getResultType();
1975 
1976  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1977  return emitOpError("expected number of static shape dims to be equal to "
1978  "the output rank (")
1979  << resultType.getRank() << ") but found "
1980  << getStaticOutputShape().size() << " inputs instead";
1981 
1982  if ((int64_t)getOutputShape().size() !=
1983  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1984  return emitOpError("mismatch in dynamic dims in output_shape and "
1985  "static_output_shape: static_output_shape has ")
1986  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1987  << " dynamic dims while output_shape has " << getOutputShape().size()
1988  << " values";
1989 
1990  return verifyTensorReshapeOp(*this, resultType, srcType);
1991 }
1992 
1993 LogicalResult CollapseShapeOp::verify() {
1994  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1995 }
1996 
1997 namespace {
1998 /// Reshape of a splat constant can be replaced with a constant of the result
1999 /// type.
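///
/// Illustrative example (shapes chosen arbitrarily):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<2x4xf32>
/// %r = tensor.collapse_shape %cst [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %r = arith.constant dense<1.0> : tensor<8xf32>
/// ```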
2000 template <typename TensorReshapeOp>
2001 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2002  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2003  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2004  PatternRewriter &rewriter) const override {
2005  DenseElementsAttr attr;
2006  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2007  return failure();
2008  if (!attr || !attr.isSplat())
2009  return failure();
2010  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2011  reshapeOp.getResultType(), attr.getRawData());
2012  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2013  return success();
2014  }
2015 };
2016 
2017 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
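//
// Illustrative example (shapes chosen arbitrarily):
//
// ```mlir
// %t = tensor.splat %x : tensor<2x3xf32>
// %r = tensor.collapse_shape %t [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
// ```
//
// becomes
//
// ```mlir
// %r = tensor.splat %x : tensor<6xf32>
// ```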
2018 template <typename TensorReshapeOp>
2019 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2020 public:
2021  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2022 
2023  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2024  PatternRewriter &rewriter) const override {
2025  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2026  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2027  return failure();
2028 
2029  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2030  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2031  return success();
2032  }
2033 };
2034 
2035 /// Reshape of a FromElements can be replaced with a FromElements of the
2036 /// result type
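///
/// Illustrative example (static shapes chosen arbitrarily):
///
/// ```mlir
/// %t = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>
/// %r = tensor.expand_shape %t [[0, 1]] output_shape [2, 2]
///     : tensor<4xf32> into tensor<2x2xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %r = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
/// ```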
2037 template <typename TensorReshapeOp>
2038 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2039  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2040  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2041  PatternRewriter &rewriter) const override {
2042  auto fromElements =
2043  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2044  if (!fromElements)
2045  return failure();
2046 
2047  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2048 
2049  if (!shapedTy.hasStaticShape())
2050  return failure();
2051 
2052  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2053  fromElements.getElements());
2054  return success();
2055  }
2056 };
2057 
2058 // Fold CastOp into CollapseShapeOp when adding static information.
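//
// Illustrative example (the cast erases static sizes that the fold recovers):
//
// ```mlir
// %c = tensor.cast %src : tensor<4x8xf32> to tensor<?x8xf32>
// %r = tensor.collapse_shape %c [[0, 1]] : tensor<?x8xf32> into tensor<?xf32>
// ```
//
// becomes
//
// ```mlir
// %0 = tensor.collapse_shape %src [[0, 1]] : tensor<4x8xf32> into tensor<32xf32>
// %r = tensor.cast %0 : tensor<32xf32> to tensor<?xf32>
// ```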
2059 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2060  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2061 
2062  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2063  PatternRewriter &rewriter) const override {
2064  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2065  if (!tensor::canFoldIntoConsumerOp(castOp))
2066  return failure();
2067 
2068  RankedTensorType srcType =
2069  llvm::cast<RankedTensorType>(castOp.getSource().getType());
2070  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2071  srcType, collapseShapeOp.getReassociationMaps());
2072 
2073  if (newResultType == collapseShapeOp.getResultType()) {
2074  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2075  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2076  });
2077  } else {
2078  auto newOp = rewriter.create<CollapseShapeOp>(
2079  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
2080  collapseShapeOp.getReassociation());
2081  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2082  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2083  }
2084  return success();
2085  }
2086 };
2087 
2088 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2089 /// matching constant output_shape operands of the expand. This makes the
2090 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2091 /// propagated further.
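///
/// For illustration, a hedged sketch with a constant output_shape operand:
///
/// ```mlir
/// %c4 = arith.constant 4 : index
/// %s = tensor.cast %src : tensor<32xf32> to tensor<?xf32>
/// %e = tensor.expand_shape %s [[0, 1]] output_shape [%c4, 8]
///     : tensor<?xf32> into tensor<?x8xf32>
/// ```
///
/// becomes roughly
///
/// ```mlir
/// %0 = tensor.cast %s : tensor<?xf32> to tensor<32xf32>
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [4, 8]
///     : tensor<32xf32> into tensor<4x8xf32>
/// %e = tensor.cast %1 : tensor<4x8xf32> to tensor<?x8xf32>
/// ```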
2092 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2093  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2094 
2095  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2096  PatternRewriter &rewriter) const override {
2097  auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2098  if (!canFoldIntoConsumerOp(castOp))
2099  return failure();
2100 
2101  ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2102  SmallVector<ReassociationIndices, 4> reassoc =
2103  expandOp.getReassociationIndices();
2104 
2105  SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2106  SmallVector<Value> dynamicOutputShape;
2107  auto outputIt = expandOp.getOutputShape().begin();
2108 
2109  for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2110  for (uint64_t outDim : innerReassoc) {
2111  if (!ShapedType::isDynamic(newOutputShape[outDim]))
2112  continue;
2113 
2114  // If the cast's src type is dynamic, don't infer any of the
2115  // corresponding expanded dimensions. `tensor.expand_shape` requires at
2116  // least one of the expanded dimensions to be dynamic if the input is
2117  // dynamic.
2118  Value val = *outputIt;
2119  ++outputIt;
2120  if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2121  dynamicOutputShape.push_back(val);
2122  continue;
2123  }
2124 
2125  APInt cst;
2126  if (matchPattern(val, m_ConstantInt(&cst))) {
2127  newOutputShape[outDim] = cst.getSExtValue();
2128  } else {
2129  dynamicOutputShape.push_back(val);
2130  }
2131  }
2132  }
2133 
2134  // Couldn't match any values, nothing to change
2135  if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2136  return failure();
2137 
2138  // Calculate the input shape from the output
2139  SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2140  for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2141  for (auto outDim : reassoc[inDim]) {
2142  auto ofr = newOutputShape[outDim];
2143  if (ShapedType::isDynamic(ofr)) {
2144  newInputShape[inDim] = ShapedType::kDynamic;
2145  break;
2146  }
2147  newInputShape[inDim] *= ofr;
2148  }
2149  }
2150 
2151  SmallVector<OpFoldResult> outputOfr =
2152  getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2153  auto inputType = RankedTensorType::get(
2154  newInputShape, expandOp.getSrcType().getElementType());
2155  auto outputType = RankedTensorType::get(
2156  newOutputShape, expandOp.getSrcType().getElementType());
2157  auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
2158  expandOp.getSrc());
2159  auto newExpand = rewriter.create<ExpandShapeOp>(
2160  expandOp.getLoc(), outputType, inputCast.getResult(),
2161  expandOp.getReassociationIndices(), outputOfr);
2162  rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2163  newExpand.getResult());
2164  return success();
2165  }
2166 };
2167 } // namespace
2168 
2169 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2170  MLIRContext *context) {
2171  results.add<
2172  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2173  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2174  ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2175  FoldReshapeWithSplat<ExpandShapeOp>,
2176  FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2177 }
2178 
2179 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2180  MLIRContext *context) {
2181  results.add<
2182  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2183  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2184  tensor::DimOp, RankedTensorType>,
2185  FoldReshapeWithConstant<CollapseShapeOp>,
2186  FoldReshapeWithSplat<CollapseShapeOp>,
2187  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2188  context);
2189 }
2190 
2191 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2192  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2193  adaptor.getOperands());
2194 }
2195 
2196 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2197  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2198  adaptor.getOperands());
2199 }
2200 
2201 //===----------------------------------------------------------------------===//
2202 // ExtractSliceOp
2203 //===----------------------------------------------------------------------===//
2204 
2205 void ExtractSliceOp::getAsmResultNames(
2206  function_ref<void(Value, StringRef)> setNameFn) {
2207  setNameFn(getResult(), "extracted_slice");
2208 }
2209 
2210 /// An extract_slice result type can be inferred, when it is not
2211 /// rank-reduced, from the source type and the static representation of
2212 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2213 RankedTensorType ExtractSliceOp::inferResultType(
2214  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2215  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2216  // An extract_slice op may specify only a leading subset of offsets/sizes/
2217  // strides in which case we complete with offset=0, sizes from the source
2218  // tensor type and strides=1.
2219  assert(static_cast<int64_t>(staticSizes.size()) ==
2220  sourceTensorType.getRank() &&
2221  "unexpected staticSizes not equal to rank of source");
2222  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2223  sourceTensorType.getEncoding());
2224 }
2225 
2226 RankedTensorType ExtractSliceOp::inferResultType(
2227  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2228  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2229  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2230  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2231  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2232  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2233  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2234  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2235  staticSizes, staticStrides);
2236 }
2237 
2238 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2239 /// number of sizes), drop as many size 1 as needed to produce an inferred
2240 /// type with the desired rank.
2241 ///
2242 /// Note that there may be multiple ways to compute this rank-reduced type:
2243 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2244 ///
2245 /// To disambiguate, this function always drops the first 1 sizes occurrences.
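///
/// For example (illustrative): with a 1x6x1 slice and desiredResultRank = 2,
/// the leading unit size is dropped and the inferred shape is 6x1; with
/// desiredResultRank = 1, both unit sizes are dropped, leaving shape 6.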
2246 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2247  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2248  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2249  ArrayRef<int64_t> strides) {
2250  // Type inferred in the absence of rank-reducing behavior.
2251  auto inferredType = llvm::cast<RankedTensorType>(
2252  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2253  int rankDiff = inferredType.getRank() - desiredResultRank;
2254  if (rankDiff > 0) {
2255  auto shape = inferredType.getShape();
2256  llvm::SmallBitVector dimsToProject =
2257  getPositionsOfShapeOne(rankDiff, shape);
2258  SmallVector<int64_t> projectedShape;
2259  // Best effort rank-reducing: drop 1s in order.
2260  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2261  if (!dimsToProject.test(pos))
2262  projectedShape.push_back(shape[pos]);
2263  inferredType =
2264  RankedTensorType::get(projectedShape, inferredType.getElementType());
2265  }
2266  return inferredType;
2267 }
2268 
2269 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2270  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2271  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2272  ArrayRef<OpFoldResult> strides) {
2273  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2274  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2275  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2276  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2277  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2278  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2279  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2280  staticStrides);
2281 }
2282 
2283 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2284 /// result type. If the type passed is nullptr, it is inferred.
2285 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2286  RankedTensorType resultType, Value source,
2287  ArrayRef<OpFoldResult> offsets,
2288  ArrayRef<OpFoldResult> sizes,
2289  ArrayRef<OpFoldResult> strides,
2290  ArrayRef<NamedAttribute> attrs) {
2291  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2292  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2293  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2294  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2295  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2296  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2297  // Structuring implementation this way avoids duplication between builders.
2298  if (!resultType) {
2299  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2300  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2301  }
2302  result.addAttributes(attrs);
2303  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2304  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2305  b.getDenseI64ArrayAttr(staticSizes),
2306  b.getDenseI64ArrayAttr(staticStrides));
2307 }
2308 
2309 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2310 /// result type.
2311 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2312  ArrayRef<OpFoldResult> offsets,
2313  ArrayRef<OpFoldResult> sizes,
2314  ArrayRef<OpFoldResult> strides,
2315  ArrayRef<NamedAttribute> attrs) {
2316  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2317 }
2318 
2319 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2320 /// a Range vector.
2321 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2322  ArrayRef<Range> ranges,
2323  ArrayRef<NamedAttribute> attrs) {
2324  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2325  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2326 }
2327 
2328 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2329 /// the type passed is nullptr, it is inferred.
2330 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2331  RankedTensorType resultType, Value source,
2332  ValueRange offsets, ValueRange sizes,
2333  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2334  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2335  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2336  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2337  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2338  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2339  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2340  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2341 }
2342 
2343 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2344 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2345  ValueRange offsets, ValueRange sizes,
2346  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2347  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2348 }
2349 
2350 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2351  Operation *op,
2352  RankedTensorType expectedType) {
2353  switch (result) {
2354  case SliceVerificationResult::Success:
2355  return success();
2356  case SliceVerificationResult::RankTooLarge:
2357  return op->emitError("expected rank to be smaller or equal to ")
2358  << "the other rank. ";
2359  case SliceVerificationResult::SizeMismatch:
2360  return op->emitError("expected type to be ")
2361  << expectedType << " or a rank-reduced version. (size mismatch) ";
2362  case SliceVerificationResult::ElemTypeMismatch:
2363  return op->emitError("expected element type to be ")
2364  << expectedType.getElementType();
2365  default:
2366  llvm_unreachable("unexpected extract_slice op verification result");
2367  }
2368 }
2369 
2370 /// Verifier for ExtractSliceOp.
2371 LogicalResult ExtractSliceOp::verify() {
2372  RankedTensorType sourceType = getSourceType();
2373 
2374  // Verify result type against inferred type.
2375  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2376  sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2377  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2378  if (result != SliceVerificationResult::Success)
2379  return produceSliceErrorMsg(result, *this, expectedType);
2380 
2381  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2382  // to the source tensor.
2383  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2384  sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2385  getStaticStrides(), /*generateErrorMessage=*/true);
2386  if (!boundsResult.isValid)
2387  return getOperation()->emitError(boundsResult.errorMessage);
2388 
2389  return success();
2390 }
2391 
2392 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2393  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2394 }
2395 
2396 FailureOr<Value>
2397 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2398  ArrayRef<int64_t> desiredShape) {
2399  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2400  assert(sourceTensorType && "not a ranked tensor type");
2401  auto sourceShape = sourceTensorType.getShape();
2402  if (sourceShape.equals(desiredShape))
2403  return value;
2404  auto maybeRankReductionMask =
2405  mlir::computeRankReductionMask(sourceShape, desiredShape);
2406  if (!maybeRankReductionMask)
2407  return failure();
2408  return createCanonicalRankReducingExtractSliceOp(
2409  b, loc, value,
2410  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2411 }
2412 
2413 LogicalResult ExtractSliceOp::reifyResultShapes(
2414  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2415  reifiedReturnShapes.resize(1);
2416  reifiedReturnShapes[0].reserve(getType().getRank());
2417  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2418  llvm::SmallBitVector droppedDims = getDroppedDims();
2419  for (const auto &size : enumerate(mixedSizes)) {
2420  if (droppedDims.test(size.index()))
2421  continue;
2422  reifiedReturnShapes[0].push_back(size.value());
2423  }
2424  return success();
2425 }
2426 
2427 namespace {
2428 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2429 /// This essentially pushes the tensor.cast past its consuming slice when
2430 /// `canFoldIntoConsumerOp` is true.
2431 ///
2432 /// Example:
2433 /// ```
2434 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2435 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2436 /// tensor<3x4xf32>
2437 /// ```
2438 /// is rewritten into:
2439 /// ```
2440 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2441 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2442 /// ```
2443 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2444 public:
2445  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2446 
2447  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2448  PatternRewriter &rewriter) const override {
2449  // Any constant operand, just return to let the constant folder kick in.
2450  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2451  return matchPattern(operand, matchConstantIndex());
2452  }))
2453  return failure();
2454 
2455  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2456  if (!castOp)
2457  return failure();
2458 
2459  if (!canFoldIntoConsumerOp(castOp))
2460  return failure();
2461 
2462  // Pattern does not apply if the produced op would not verify.
2463  SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2464  cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2465  sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2466  sliceOp.getStaticStrides());
2467  if (!sliceResult.isValid)
2468  return failure();
2469 
2470  // Create folded extract.
2471  Location loc = sliceOp.getLoc();
2472  Value newResult = rewriter.create<ExtractSliceOp>(
2473  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2474  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2475  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2476  rewriter.replaceOp(sliceOp, newResult);
2477  return success();
2478  }
2479 };
2480 
2481 /// Slice elements from `values` into `outValues`. `counts` holds the
2482 /// linearized stride of each dimension in the original values. The output
2483 /// values can be used to construct a DenseElementsAttr.
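///
/// Small worked example (illustrative): for a 2x4 source holding
/// [[0, 1, 2, 3], [4, 5, 6, 7]] with counts = [4, 1], offsets = [0, 1],
/// sizes = [2, 2] and strides = [1, 2], the selected elements are
/// [1, 3, 5, 7].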
2484 template <typename IterTy, typename ElemTy>
2485 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2486  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2487  ArrayRef<int64_t> strides,
2488  llvm::SmallVectorImpl<ElemTy> *outValues) {
2489  assert(offsets.size() == sizes.size());
2490  assert(offsets.size() == strides.size());
2491  if (offsets.empty())
2492  return;
2493 
2494  int64_t offset = offsets.front();
2495  int64_t size = sizes.front();
2496  int64_t stride = strides.front();
2497  if (offsets.size() == 1) {
2498  for (int64_t i = 0; i < size; ++i, offset += stride)
2499  outValues->push_back(*(values + offset));
2500 
2501  return;
2502  }
2503 
2504  for (int64_t i = 0; i < size; ++i, offset += stride) {
2505  auto begin = values + offset * counts.front();
2506  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2507  offsets.drop_front(), sizes.drop_front(),
2508  strides.drop_front(), outValues);
2509  }
2510 }
2511 
2512 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2513 /// folded operation might introduce more constant data; users can control
2514 /// whether the fold applies through the control function.
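///
/// Illustrative example (subject to the control function accepting the op):
///
/// ```mlir
/// %cst = arith.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi32>
/// %s = tensor.extract_slice %cst[0, 1] [2, 2] [1, 2]
///     : tensor<2x4xi32> to tensor<2x2xi32>
/// ```
///
/// folds into
///
/// ```mlir
/// %s = arith.constant dense<[[1, 3], [5, 7]]> : tensor<2x2xi32>
/// ```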
2515 class ConstantOpExtractSliceFolder final
2516  : public OpRewritePattern<ExtractSliceOp> {
2517 public:
2518  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2519 
2520  ConstantOpExtractSliceFolder(MLIRContext *context,
2521  ControlConstantExtractSliceFusionFn controlFn)
2522  : OpRewritePattern<ExtractSliceOp>(context),
2523  controlFn(std::move(controlFn)) {}
2524 
2525  LogicalResult matchAndRewrite(ExtractSliceOp op,
2526  PatternRewriter &rewriter) const override {
2527  DenseElementsAttr attr;
2528  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2529  return failure();
2530 
2531  // A constant splat is handled by fold().
2532  if (attr.isSplat())
2533  return failure();
2534 
2535  // Dynamic result shape is not supported.
2536  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2537  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2538  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2539  return failure();
2540 
2541  // Customized control over the folding.
2542  if (!controlFn(op))
2543  return failure();
2544 
2545  int64_t count = sourceType.getNumElements();
2546  if (count == 0)
2547  return failure();
2548 
2549  // Check if there are any dynamic parts, which are not supported.
2550  auto offsets = op.getStaticOffsets();
2551  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2552  return failure();
2553  auto sizes = op.getStaticSizes();
2554  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2555  return failure();
2556  auto strides = op.getStaticStrides();
2557  if (llvm::is_contained(strides, ShapedType::kDynamic))
2558  return failure();
2559 
2560  // Compute the stride for each dimension.
2561  SmallVector<int64_t> counts;
2562  ArrayRef<int64_t> shape = sourceType.getShape();
2563  counts.reserve(shape.size());
2564  for (int64_t v : shape) {
2565  count = count / v;
2566  counts.push_back(count);
2567  }
2568 
2569  // New attribute constructed by the sliced values.
2570  DenseElementsAttr newAttr;
2571 
2572  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2573  SmallVector<APInt> outValues;
2574  outValues.reserve(sourceType.getNumElements());
2575  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2576  elems.begin(), counts, offsets, sizes, strides, &outValues);
2577  newAttr = DenseElementsAttr::get(resultType, outValues);
2578  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2579  SmallVector<APFloat> outValues;
2580  outValues.reserve(sourceType.getNumElements());
2581  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2582  elems.begin(), counts, offsets, sizes, strides, &outValues);
2583  newAttr = DenseElementsAttr::get(resultType, outValues);
2584  }
2585 
2586  if (newAttr) {
2587  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2588  return success();
2589  }
2590 
2591  return failure();
2592  }
2593 
2594 private:
2595  /// This additionally controls whether the fold happens or not. Users can
2596  /// impose their heuristics in the function.
2597  ControlConstantExtractSliceFusionFn controlFn;
2598 };
2599 
2600 } // namespace
2601 
2602 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2603  RewritePatternSet &patterns,
2604  const ControlConstantExtractSliceFusionFn &controlFn) {
2605  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2606 }
2607 
2608 /// Return the canonical type of the result of an extract_slice op.
2609 struct SliceReturnTypeCanonicalizer {
2610  RankedTensorType operator()(ExtractSliceOp op,
2611  ArrayRef<OpFoldResult> mixedOffsets,
2612  ArrayRef<OpFoldResult> mixedSizes,
2613  ArrayRef<OpFoldResult> mixedStrides) {
2614  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2615  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2616  mixedStrides);
2617  }
2618 };
2619 
2620 /// A canonicalizer wrapper to replace ExtractSliceOps.
2621 struct SliceCanonicalizer {
2622  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2623  ExtractSliceOp newOp) {
2624  Value replacement = newOp.getResult();
2625  if (replacement.getType() != op.getType())
2626  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2627  replacement);
2628  rewriter.replaceOp(op, replacement);
2629  }
2630 };
2631 
2632 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2633  MLIRContext *context) {
2634  results.add<
2635  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2636  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2637  ExtractSliceOpCastFolder>(context);
2638 }
2639 
2640 // Returns success if `op` is a no-op slice of `shapedType` (zero offsets, full sizes, unit strides).
2641 static LogicalResult
2642 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2643  ShapedType shapedType) {
2644  OpBuilder b(op.getContext());
2645  for (OpFoldResult ofr : op.getMixedOffsets())
2646  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2647  return failure();
2648  // Rank-reducing noops only need to inspect the leading dimensions:
2649  // llvm::zip is appropriate.
2650  auto shape = shapedType.getShape();
2651  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2652  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2653  return failure();
2654  for (OpFoldResult ofr : op.getMixedStrides())
2655  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2656  return failure();
2657  return success();
2658 }
2659 
2660 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2661 /// slice, we can return the InsertSliceOp's source directly.
2662 // TODO: This only checks the immediate producer; extend to go up the
2663 // insert/extract chain if the slices are disjoint.
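//
// Illustrative example (shapes chosen arbitrarily):
//
// ```mlir
// %0 = tensor.insert_slice %src into %dest[0, 0] [16, 16] [1, 1]
//     : tensor<16x16xf32> into tensor<64x64xf32>
// %1 = tensor.extract_slice %0[0, 0] [16, 16] [1, 1]
//     : tensor<64x64xf32> to tensor<16x16xf32>
// ```
//
// folds %1 to %src.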
2664 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2665  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2666 
2667  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2668  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2669  insertOp.isSameAs(extractOp, isSame))
2670  return insertOp.getSource();
2671 
2672  return {};
2673 }
2674 
2675 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2676  if (OpFoldResult reshapedSource = reshapeConstantSource(
2677  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2678  getResult().getType()))
2679  return reshapedSource;
2680  if (getSourceType() == getType() &&
2681  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2682  return this->getSource();
2683  if (Value slice = foldExtractAfterInsertSlice(*this))
2684  return slice;
2685 
2686  return OpFoldResult();
2687 }
2688 
2689 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2690  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2691  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2692  unsigned rank = rankedTensorType.getRank();
2693  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2694  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2695  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2696  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2697  offsets, sizes, strides);
2698 }
2699 
2700 //===----------------------------------------------------------------------===//
2701 // InsertSliceOp
2702 //===----------------------------------------------------------------------===//
2703 
2704 void InsertSliceOp::getAsmResultNames(
2705  function_ref<void(Value, StringRef)> setNameFn) {
2706  setNameFn(getResult(), "inserted_slice");
2707 }
2708 
2709 // Build an InsertSliceOp with mixed static and dynamic entries.
2710 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2711  Value dest, ArrayRef<OpFoldResult> offsets,
2712  ArrayRef<OpFoldResult> sizes,
2713  ArrayRef<OpFoldResult> strides,
2714  ArrayRef<NamedAttribute> attrs) {
2715  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2716  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2717  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2718  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2719  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2720  result.addAttributes(attrs);
2721  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2722  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2723  b.getDenseI64ArrayAttr(staticSizes),
2724  b.getDenseI64ArrayAttr(staticStrides));
2725 }
2726 
2727 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2728 /// Range vector.
2729 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2730  Value dest, ArrayRef<Range> ranges,
2731  ArrayRef<NamedAttribute> attrs) {
2732  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2733  build(b, result, source, dest, offsets, sizes, strides, attrs);
2734 }
2735 
2736 // Build an InsertSliceOp with dynamic entries.
2737 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2738  Value dest, ValueRange offsets, ValueRange sizes,
2739  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2740  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2741  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2742  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2743  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2744  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2745  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2746  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2747 }
2748 
2749 /// Rank-reducing type verification for both InsertSliceOp and
2750 /// ParallelInsertSliceOp.
2751 static SliceVerificationResult verifyInsertSliceOp(
2752  RankedTensorType srcType, RankedTensorType dstType,
2753  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2754  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2755  // insert_slice is the inverse of extract_slice, use the same type
2756  // inference.
2757  RankedTensorType expected = ExtractSliceOp::inferResultType(
2758  dstType, staticOffsets, staticSizes, staticStrides);
2759  if (expectedType)
2760  *expectedType = expected;
2761  return isRankReducedType(expected, srcType);
2762 }
2763 
2764 /// Verifier for InsertSliceOp.
2765 LogicalResult InsertSliceOp::verify() {
2766  // Verify result type against inferred type.
2767  RankedTensorType expectedType;
2768  SliceVerificationResult result =
2769  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2770  getStaticSizes(), getStaticStrides(), &expectedType);
2771  if (result != SliceVerificationResult::Success)
2772  return produceSliceErrorMsg(result, *this, expectedType);
2773 
2774  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2775  // to the destination tensor.
2776  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2777  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2778  getStaticStrides(), /*generateErrorMessage=*/true);
2779  if (!boundsResult.isValid)
2780  return getOperation()->emitError(boundsResult.errorMessage);
2781 
2782  return success();
2783 }
2784 
2785 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2786 /// can mutate the second InsertSliceOp's destination to the first one's.
2787 ///
2788 /// Example:
2789 ///
2790 /// ```mlir
2791 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2792 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2793 /// ```
2794 ///
2795 /// folds into:
2796 ///
2797 /// ```mlir
2798 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2799 /// ```
2800 ///
2801 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2802 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2803  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2804 
2805  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2806  if (!prevInsertOp ||
2807  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2808  !prevInsertOp.isSameAs(insertOp, isSame))
2809  return failure();
2810 
2811  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2812  return success();
2813 }
2814 
2815 /// Folds round-trip extract/insert slice op pairs.
2816 /// Example:
2817 /// ```mlir
2818 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2819 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2820 /// ```
2821 /// can be folded into %val.
2822 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2823  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2824 
2825  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2826  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2827  !extractOp.isSameAs(insertOp, isSame))
2828  return nullptr;
2829 
2830  return extractOp.getSource();
2831 }
2832 
2833 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2834  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2835  getSourceType() == getType() &&
2836  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2837  return this->getSource();
2838  if (succeeded(foldInsertAfterInsertSlice(*this)))
2839  return getResult();
2840  if (auto result = foldInsertAfterExtractSlice(*this))
2841  return result;
2842  if (llvm::any_of(getMixedSizes(), isZeroInteger))
2843  return getDest();
2844  return OpFoldResult();
2845 }
2846 
2847 LogicalResult InsertSliceOp::reifyResultShapes(
2848  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2849  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2850  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2851  return success();
2852 }
2853 
2854 namespace {
2855 /// Pattern to rewrite an insert_slice op with constant arguments.
2856 ///
2857 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2858 template <typename InsertOpTy>
2859 class InsertSliceOpConstantArgumentFolder final
2860  : public OpRewritePattern<InsertOpTy> {
2861 public:
2863 
2864  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2865  PatternRewriter &rewriter) const override {
2866  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2867  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2868  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2869 
2870  // No constant operands were folded, just return.
2871  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2872  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2873  failed(foldDynamicStrideList(mixedStrides)))
2874  return failure();
2875 
2876  // Pattern does not apply if the produced op would not verify.
2877  SliceBoundsVerificationResult sliceResult =
2878  verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2879  mixedOffsets, mixedSizes, mixedStrides);
2880  if (!sliceResult.isValid)
2881  return failure();
2882 
2883  // Create the new op in canonical form.
2884  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2885  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2886  mixedOffsets, mixedSizes, mixedStrides);
2887  Value toInsert = insertSliceOp.getSource();
2888  if (sourceType != insertSliceOp.getSourceType()) {
2889  OpBuilder::InsertionGuard g(rewriter);
2890  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2891  // is that the insertion point is just before the ParallelCombiningOp in
2892  // the parallel case.
2893  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2894  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2895  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2896  sourceType, toInsert);
2897  }
2898  rewriter.replaceOpWithNewOp<InsertOpTy>(
2899  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2900  mixedSizes, mixedStrides);
2901  return success();
2902  }
2903 };
2904 
2905 /// Fold tensor_casts with insert_slice operations. If the source or
2906 /// destination tensor is a tensor_cast that removes static type information,
2907 /// the cast is folded into the insert_slice operation. E.g.:
2908 ///
2909 /// ```mlir
2910 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2911 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2912 /// ```
2913 ///
2914 /// folds into:
2915 ///
2916 /// ```mlir
2917 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2918 /// ```
2919 ///
2920 /// Note: When folding a cast on the destination tensor, the result of the
2921 /// insert_slice operation is casted to ensure that the type of the result did
2922 /// not change.
2923 ///
2924 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2925 template <typename InsertOpTy>
2926 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2927  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2928 
2929  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2930  PatternRewriter &rewriter) const override {
2931  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2932  return matchPattern(operand, matchConstantIndex());
2933  }))
2934  return failure();
2935 
2936  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2937  auto castOp = v.getDefiningOp<tensor::CastOp>();
2938  if (!castOp || !canFoldIntoConsumerOp(castOp))
2939  return std::nullopt;
2940  return castOp.getSource();
2941  };
2942  std::optional<Value> sourceCastSource =
2943  getSourceOfCastOp(insertSliceOp.getSource());
2944  std::optional<Value> destCastSource =
2945  getSourceOfCastOp(insertSliceOp.getDest());
2946  if (!sourceCastSource && !destCastSource)
2947  return failure();
2948 
2949  auto src =
2950  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2951  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2952  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2953  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2954  if (!srcType || !dstType)
2955  return failure();
2956 
2957  // The tensor.cast source could have additional static information not seen
2958  // in the insert slice op static sizes, so we ignore dynamic dims when
2959  // computing the rank reduction mask.
2960  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2961  auto rankReductionMask = computeRankReductionMask(
2962  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2963  if (!rankReductionMask.has_value())
2964  return failure();
2965  // Replace dimensions in the insert slice op with corresponding static dims
2966  // from the cast source type. If the insert slice sizes have static dims
2967  // that are not static in the tensor.cast source (i.e., when the cast op
2968  // casts a dynamic dim to static), the dim should not be replaced, and the
2969  // pattern will fail later in `verifyInsertSliceOp`.
2970  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2971  int64_t rankReducedIdx = 0;
2972  for (auto [idx, size] : enumerate(staticSizes)) {
2973  if (!rankReductionMask.value().contains(idx) &&
2974  !srcType.isDynamicDim(rankReducedIdx)) {
2975  mixedSizes[idx] = getAsIndexOpFoldResult(
2976  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2977  size = srcType.getDimSize(rankReducedIdx++);
2978  }
2979  }
2980 
2981  // Pattern does not apply if the produced op would not verify.
2982  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2983  staticSizes, insertSliceOp.getStaticStrides()) !=
2984  SliceVerificationResult::Success)
2985  return failure();
2986  SliceBoundsVerificationResult sliceResult =
2987  verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
2988  mixedSizes, insertSliceOp.getMixedStrides());
2989  if (!sliceResult.isValid)
2990  return failure();
2991 
2992  Operation *replacement = rewriter.create<InsertOpTy>(
2993  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2994  mixedSizes, insertSliceOp.getMixedStrides());
2995 
2996  // In the parallel case there is no result and so nothing to cast.
2997  bool isParallelInsert =
2998  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2999  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3000  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
3001  insertSliceOp.getDestType(),
3002  replacement->getResult(0));
3003  }
3004  rewriter.replaceOp(insertSliceOp, replacement->getResults());
3005  return success();
3006  }
3007 };
3008 
3009 /// If additional static type information can be deduced from an insert_slice's
3010 /// size operands, insert an explicit cast of the op's source operand. This
3011 /// enables other canonicalization patterns that are matching for tensor_cast
3012 /// ops such as `ForOpTensorCastFolder` in SCF.
3013 ///
3014 /// Example:
3015 ///
3016 /// ```mlir
3017 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3018 /// : tensor<?x?xf32> into ...
3019 /// ```
3020 ///
3021 /// folds into:
3022 ///
3023 /// ```mlir
3024 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3025 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3026 /// : tensor<64x64xf32> into ...
3027 /// ```
3028 ///
3029 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3030 template <typename InsertOpTy>
3031 struct InsertSliceOpSourceCastInserter final
3032  : public OpRewritePattern<InsertOpTy> {
3033  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3034 
3035  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3036  PatternRewriter &rewriter) const override {
3037  RankedTensorType srcType = insertSliceOp.getSourceType();
3038  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3039  return failure();
3040  SmallVector<int64_t> newSrcShape(srcType.getShape());
3041  for (int64_t i = 0; i < srcType.getRank(); ++i) {
3042  if (std::optional<int64_t> constInt =
3043  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3044  // Bail on invalid IR.
3045  if (*constInt < 0)
3046  return failure();
3047  newSrcShape[i] = *constInt;
3048  }
3049  }
3050  if (!hasValidSizesOffsets(newSrcShape))
3051  return failure();
3052 
3053  RankedTensorType newSrcType = RankedTensorType::get(
3054  newSrcShape, srcType.getElementType(), srcType.getEncoding());
3055  if (srcType == newSrcType ||
3056  !preservesStaticInformation(srcType, newSrcType) ||
3057  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3058  return failure();
3059 
3060  // newSrcType is:
3061  // 1) Different from srcType.
3062  // 2) "More static" than srcType.
3063  // 3) Cast-compatible with srcType.
3064  // Insert the cast.
3065  OpBuilder::InsertionGuard g(rewriter);
3066  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3067  // that the insertion point is just before the ParallelCombiningOp in the
3068  // parallel case.
3069  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
3070  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3071  Value cast = rewriter.create<tensor::CastOp>(
3072  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
3073  rewriter.replaceOpWithNewOp<InsertOpTy>(
3074  insertSliceOp, cast, insertSliceOp.getDest(),
3075  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3076  insertSliceOp.getMixedStrides());
3077  return success();
3078  }
3079 };
3080 } // namespace
3081 
3082 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3083  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3084 }
3085 
3086 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3087  MLIRContext *context) {
3088  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3089  InsertSliceOpCastFolder<InsertSliceOp>,
3090  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3091 }
3092 
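/// The helper below builds an insert_slice that writes `tensor` into `dest` at
/// all-zero offsets, with unit strides and the full sizes of `dest`. As an
/// illustrative sketch (assuming a tensor<4xf32> source and a tensor<1x4xf32>
/// destination), it emits:
///
/// ```mlir
/// %r = tensor.insert_slice %t into %d[0, 0] [1, 4] [1, 1]
///     : tensor<4xf32> into tensor<1x4xf32>
/// ```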
3093 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3094  Location loc,
3095  Value tensor,
3096  Value dest) {
3097  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3098  unsigned rank = rankedTensorType.getRank();
3099  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3100  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3101  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3102  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3103  sizes, strides);
3104 }
3105 
3106 //===----------------------------------------------------------------------===//
3107 // PadOp
3108 //===----------------------------------------------------------------------===//
3109 
3110 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3111  setNameFn(getResult(), "padded");
3112 }
3113 
3114 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3115 // supports optional types.
3116 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3117  Type typeToInfer, Type typeToInferFrom) {}
3118 
3119 ParseResult
3120 parseInferType(OpAsmParser &parser,
3121  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3122  Type &typeToInfer, Type typeToInferFrom) {
3123  if (optOperand)
3124  typeToInfer = typeToInferFrom;
3125  return success();
3126 }
3127 
3128 LogicalResult PadOp::verify() {
3129  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3130  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3131  auto expectedType =
3132  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3133  if (!expectedType) {
3134  return emitError("failed to infer expectedType from sourceType ")
3135  << sourceType << ", specified resultType is " << resultType;
3136  }
3137  if (resultType.getRank() != expectedType.getRank()) {
3138  return emitError("specified type ")
3139  << resultType << " does not match the inferred type "
3140  << expectedType;
3141  }
3142  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3143  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3144  continue;
3145  if (expectedType.isDynamicDim(i))
3146  continue;
3147  return emitError("specified type ")
3148  << resultType << " does not match the inferred type "
3149  << expectedType;
3150  }
3151 
3152  return success();
3153 }
3154 
3155 LogicalResult PadOp::verifyRegions() {
3156  auto &region = getRegion();
3157  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3158  Block &block = region.front();
3159  if (block.getNumArguments() != rank)
3160  return emitError("expected the block to have ") << rank << " arguments";
3161 
3162  // Note: the number and type of yield values are checked in the YieldOp.
3163  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3164  if (!en.value().isIndex())
3165  return emitOpError("expected block argument ")
3166  << (en.index() + 1) << " to be an index";
3167  }
3168 
3169  // Ensure that the region yields an element of the right type.
3170  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3171  if (yieldOp.getValue().getType() !=
3172  llvm::cast<ShapedType>(getType()).getElementType())
3173  return emitOpError("expected yield type to match shape element type");
3174 
3175  return success();
3176 }
3177 
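/// Illustrative sketch of the inference below: for a tensor<4x?xf32> source
/// with staticLow = [1, 0] and staticHigh = [2, ShapedType::kDynamic], the
/// inferred type is tensor<7x?xf32>; dim 0 is 4 + 1 + 2 = 7, and dim 1 stays
/// dynamic because the source dim (and the high padding) is dynamic.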
3178 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3179  ArrayRef<int64_t> staticLow,
3180  ArrayRef<int64_t> staticHigh,
3181  ArrayRef<int64_t> resultShape) {
3182  unsigned rank = sourceType.getRank();
3183  if (staticLow.size() != rank)
3184  return RankedTensorType();
3185  if (staticHigh.size() != rank)
3186  return RankedTensorType();
3187  if (!resultShape.empty() && resultShape.size() != rank)
3188  return RankedTensorType();
3189 
3190  SmallVector<int64_t, 4> inferredShape;
3191  for (auto i : llvm::seq<unsigned>(0, rank)) {
3192  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3193  staticHigh[i] == ShapedType::kDynamic) {
3194  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3195  : resultShape[i]);
3196  } else {
3197  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3198  assert((resultShape.empty() || size == resultShape[i] ||
3199  resultShape[i] == ShapedType::kDynamic) &&
3200  "mismatch between inferred shape and result shape");
3201  inferredShape.push_back(size);
3202  }
3203  }
3204 
3205  return RankedTensorType::get(inferredShape, sourceType.getElementType());
3206 }
3207 
3208 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3209  Value source, ArrayRef<int64_t> staticLow,
3210  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3211  bool nofold, ArrayRef<NamedAttribute> attrs) {
3212  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3213  if (!resultType)
3214  resultType = inferResultType(sourceType, staticLow, staticHigh);
3215  result.addAttributes(attrs);
3216  build(b, result, resultType, source, low, high,
3217  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3218  nofold ? b.getUnitAttr() : UnitAttr());
3219 }
3220 
3221 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3222  Value source, ValueRange low, ValueRange high, bool nofold,
3223  ArrayRef<NamedAttribute> attrs) {
3224  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3225  unsigned rank = sourceType.getRank();
3226  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3227  build(b, result, resultType, source, staticVector, staticVector, low, high,
3228  nofold, attrs);
3229 }
3230 
3231 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3232  Value source, ArrayRef<OpFoldResult> low,
3233  ArrayRef<OpFoldResult> high, bool nofold,
3234  ArrayRef<NamedAttribute> attrs) {
3235  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3236  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3237  SmallVector<int64_t, 4> staticLow, staticHigh;
3238  // staticLow and staticHigh carry the full padding configuration. Each
3239  // padding entry grows staticLow/staticHigh by one value; if the entry is
3240  // dynamic (i.e. not a constant), dynamicLow/dynamicHigh grow by one value
3241  // as well.
3242  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3243  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3244  if (!resultType) {
3245  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3246  }
3247  assert(llvm::isa<RankedTensorType>(resultType));
3248  result.addAttributes(attrs);
3249  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3250  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3251  nofold ? b.getUnitAttr() : UnitAttr());
3252 }
3253 
3254 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3255  Value source, ArrayRef<OpFoldResult> low,
3256  ArrayRef<OpFoldResult> high, Value constantPadValue,
3257  bool nofold, ArrayRef<NamedAttribute> attrs) {
3258  build(b, result, resultType, source, low, high, nofold, attrs);
3259 
3260  // Add a region and a block to yield the pad value.
3261  Region *region = result.regions[0].get();
3262  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3263  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3264  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3265 
3266  // `builder.createBlock` changes the insertion point within the block. Create
3267  // a guard to reset the insertion point of the builder after it is destroyed.
3268  OpBuilder::InsertionGuard guard(b);
3269  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3270  b.create<tensor::YieldOp>(result.location, constantPadValue);
3271 }
3272 
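/// A dimension counts as padded if its low or high padding is anything other
/// than the constant zero (non-zero constants and non-constant values both
/// count). Illustrative sketch: low[0, 2] high[%h, 0] sets bits 0 and 1, since
/// dim 0 has a non-constant high pad and dim 1 has a non-zero low pad.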
3273 llvm::SmallBitVector PadOp::getPaddedDims() {
3274  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3275  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3276  for (const auto &en : enumerate(paddingWidths))
3277  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3278  paddedDims.set(en.index());
3279  };
3280  extractPaddedDims(getMixedLowPad());
3281  extractPaddedDims(getMixedHighPad());
3282  return paddedDims;
3283 }
3284 
3285 namespace {
3286 // Folds tensor.pad when all low and high paddings are statically zero and
3287 // the nofold attribute is not set.
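// Illustrative sketch:
//
//   %padded = tensor.pad %t low[0, 0] high[0, 0] {
//     tensor.yield %cst
//   } : tensor<?x5xf32> to tensor<4x5xf32>
//
// canonicalizes to:
//
//   %padded = tensor.cast %t : tensor<?x5xf32> to tensor<4x5xf32>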
3288 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3289  using OpRewritePattern<PadOp>::OpRewritePattern;
3290 
3291  LogicalResult matchAndRewrite(PadOp padTensorOp,
3292  PatternRewriter &rewriter) const override {
3293  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3294  return failure();
3295  if (padTensorOp.getNofold())
3296  return failure();
3297  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3298  padTensorOp, padTensorOp.getResult().getType(),
3299  padTensorOp.getSource());
3300  return success();
3301  }
3302 };
3303 
3304 // Fold CastOp into PadOp when adding static information.
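// Illustrative sketch; the cast that erased static sizes is folded away and
// the recovered static shape is propagated into the pad's result type:
//
//   %c = tensor.cast %t : tensor<8x16xf32> to tensor<?x?xf32>
//   %p = tensor.pad %c low[0, 0] high[2, 2] {
//     tensor.yield %cst
//   } : tensor<?x?xf32> to tensor<?x?xf32>
//
// becomes
//
//   %p0 = tensor.pad %t low[0, 0] high[2, 2] {
//     tensor.yield %cst
//   } : tensor<8x16xf32> to tensor<10x18xf32>
//   %p  = tensor.cast %p0 : tensor<10x18xf32> to tensor<?x?xf32>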
3305 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3306  using OpRewritePattern<PadOp>::OpRewritePattern;
3307 
3308  LogicalResult matchAndRewrite(PadOp padTensorOp,
3309  PatternRewriter &rewriter) const override {
3310  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3311  if (!tensor::canFoldIntoConsumerOp(castOp))
3312  return failure();
3313 
3314  auto newResultType = PadOp::inferResultType(
3315  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3316  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3317  padTensorOp.getResultType().getShape());
3318 
3319  if (newResultType == padTensorOp.getResultType()) {
3320  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3321  padTensorOp.getSourceMutable().assign(castOp.getSource());
3322  });
3323  } else {
3324  auto newOp = rewriter.create<PadOp>(
3325  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3326  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3327  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3328  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3329  IRMapping mapper;
3330  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3331 
3332  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3333  padTensorOp, padTensorOp.getResultType(), newOp);
3334  }
3335  return success();
3336  }
3337 };
3338 
3339 // Fold CastOp using the result of PadOp back into the latter if it adds
3340 // static information.
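// Illustrative sketch:
//
//   %p = tensor.pad %t low[0] high[%h] {
//     tensor.yield %cst
//   } : tensor<?xf32> to tensor<?xf32>
//   %c = tensor.cast %p : tensor<?xf32> to tensor<10xf32>
//
// becomes
//
//   %c = tensor.pad %t low[0] high[%h] {
//     tensor.yield %cst
//   } : tensor<?xf32> to tensor<10xf32>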
3341 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3342  using OpRewritePattern<PadOp>::OpRewritePattern;
3343 
3344  LogicalResult matchAndRewrite(PadOp padTensorOp,
3345  PatternRewriter &rewriter) const override {
3346  if (!padTensorOp.getResult().hasOneUse())
3347  return failure();
3348  auto tensorCastOp =
3349  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3350  if (!tensorCastOp)
3351  return failure();
3352  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3353  tensorCastOp.getDest().getType()))
3354  return failure();
3355 
3356  auto replacementOp = rewriter.create<PadOp>(
3357  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3358  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3359  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3360  padTensorOp.getHigh(), padTensorOp.getNofold(),
3361  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3362  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3363 
3364  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3365  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3366  return success();
3367  }
3368 };
3369 
3370 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3371 /// different dimensions. The pattern applies if the following preconditions
3372 /// hold:
3373 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3374 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3375 /// 3) the tensor::PadOps perform only high-padding,
3376 /// 4) the tensor::PadOps have the same constant padding value,
3377 /// 5) the tensor::PadOps do not have common padding dimensions,
3378 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3379 /// zero-offset for every dimension.
3380 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
3381 ///    padded source dimensions.
3383 ///
3384 /// Example:
3385 ///
3386 /// ```mlir
3387 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3388 /// : tensor<64x64xf32> to tensor<?x64xf32>
3389 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3390 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3391 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3392 /// : tensor<8x64xf32> to tensor<8x?xf32>
3393 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3394 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3395 /// ```
3396 ///
3397 /// folds into:
3398 ///
3399 /// ```mlir
3400 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3401 /// : tensor<64x64xf32> to tensor<?x?xf32>
3402 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3403 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3404 /// ```
3405 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3406  using OpRewritePattern<PadOp>::OpRewritePattern;
3407 
3408  LogicalResult matchAndRewrite(PadOp padOp,
3409  PatternRewriter &rewriter) const override {
3410  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3411  if (!innerSliceOp)
3412  return failure();
3413  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3414  if (!outerPadOp || outerPadOp.getNofold())
3415  return failure();
3416  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3417  if (!outerSliceOp)
3418  return failure();
3419 
3420  // 1) Fail if the chain is rank-reducing.
3421  int64_t rank = padOp.getSourceType().getRank();
3422  if (outerSliceOp.getSourceType().getRank() != rank) {
3423  return rewriter.notifyMatchFailure(padOp,
3424  "cannot fold rank-reducing chain");
3425  }
3426 
3427  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3428  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3429  return rewriter.notifyMatchFailure(
3430  padOp, "cannot fold non-unit stride ExtractSliceOps");
3431  }
3432 
3433  // 3) Fail if the tensor::PadOps have non-zero low padding.
3434  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3435  return rewriter.notifyMatchFailure(padOp,
3436  "cannot fold PadOps with low padding");
3437  }
3438 
3439  // 4) Fail if the tensor::PadOps padding values do not match.
3440  Attribute innerAttr, outerAttr;
3441  Value innerValue = padOp.getConstantPaddingValue();
3442  Value outerValue = outerPadOp.getConstantPaddingValue();
3443  if (!innerValue || !outerValue ||
3444  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3445  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3446  innerAttr != outerAttr) {
3447  return rewriter.notifyMatchFailure(
3448  padOp, "cannot fold PadOps with different padding values");
3449  }
3450 
3451  // 5) Fail if a dimension is padded by both tensor::PadOps.
3452  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3453  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3454  if (innerDims.anyCommon(outerDims)) {
3455  return rewriter.notifyMatchFailure(
3456  padOp, "cannot fold PadOps with common padding dimensions");
3457  }
3458 
3459  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3460  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3461  // for every dimension, and use the offset of the other pair. Fail if no
3462  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3463  // exists.
3464  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3465  for (auto en : enumerate(newOffsets)) {
3466  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3467  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3468  if (!innerDims.test(en.index()) &&
3469  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3470  en.value() = outerOffset;
3471  continue;
3472  }
3473  if (!outerDims.test(en.index()) &&
3474  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3475  en.value() = innerOffset;
3476  continue;
3477  }
3478  return rewriter.notifyMatchFailure(
3479  padOp, "cannot find zero-offset and zero-padding pair");
3480  }
3481 
3482  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3483  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3484  // outer tensor::PadOp and fail if the size of the inner
3485  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3486  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3487  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3488  for (auto en : enumerate(newSizes)) {
3489  if (!outerDims.test(en.index()))
3490  continue;
3491  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3492  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3493  assert(!ShapedType::isDynamic(sourceSize) &&
3494  "expected padded dimension to have a static size");
3495  if (getConstantIntValue(sliceSize) != sourceSize) {
3496  return rewriter.notifyMatchFailure(
3497  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3498  "match the size of the outer padding");
3499  }
3500  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3501  }
3502 
3503  // Combine the high paddings of the two tensor::PadOps.
3504  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3505  for (auto en : enumerate(newHighPad)) {
3506  if (innerDims.test(en.index()))
3507  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3508  if (outerDims.test(en.index()))
3509  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3510  }
3511 
3512  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3513  // the two paddings in one step.
3514  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3515  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3516  innerSliceOp.getMixedStrides());
3517  auto newPadOp = rewriter.create<PadOp>(
3518  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3519  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3520  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3521  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3522  newPadOp.getRegion().begin());
3523  rewriter.replaceOp(padOp, newPadOp.getResult());
3524  return success();
3525  }
3526 };
3527 
3528 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3529  using OpRewritePattern<PadOp>::OpRewritePattern;
3530 
3531  LogicalResult matchAndRewrite(PadOp padTensorOp,
3532  PatternRewriter &rewriter) const override {
3533  Value input = padTensorOp.getSource();
3534  if (!llvm::isa<RankedTensorType>(input.getType()))
3535  return failure();
3536  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3537  auto inputRank = inputDims.size();
3538 
3539  auto oldResultType =
3540  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3541  if (!oldResultType)
3542  return failure();
3543 
3544  auto outputDims = oldResultType.getShape();
3545 
3546  // Extract the static info from the high and low operands.
3547  SmallVector<int64_t> constOperandsLow;
3548  SmallVector<Value> newLows;
3549  for (auto operand : padTensorOp.getLow()) {
3550  APSInt intOp;
3551  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3552  constOperandsLow.push_back(ShapedType::kDynamic);
3553  newLows.push_back(operand);
3554  continue;
3555  }
3556  constOperandsLow.push_back(intOp.getExtValue());
3557  }
3558  SmallVector<int64_t> constOperandsHigh;
3559  SmallVector<Value> newHighs;
3560  for (auto operand : padTensorOp.getHigh()) {
3561  APSInt intOp;
3562  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3563  constOperandsHigh.push_back(ShapedType::kDynamic);
3564  newHighs.push_back(operand);
3565  continue;
3566  }
3567  constOperandsHigh.push_back(intOp.getExtValue());
3568  }
3569 
3570  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3571  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3572 
3573  // Verify the op is well-formed.
3574  if (inputDims.size() != outputDims.size() ||
3575  inputDims.size() != constLow.size() ||
3576  inputDims.size() != constHigh.size())
3577  return failure();
3578 
3579  auto lowCount = 0;
3580  auto highCount = 0;
3581  for (size_t i = 0; i < inputRank; i++) {
3582  if (constLow[i] == ShapedType::kDynamic)
3583  constLow[i] = constOperandsLow[lowCount++];
3584  if (constHigh[i] == ShapedType::kDynamic)
3585  constHigh[i] = constOperandsHigh[highCount++];
3586  }
3587 
3588  auto staticLow = ArrayRef<int64_t>(constLow);
3589  auto staticHigh = ArrayRef<int64_t>(constHigh);
3590 
3591  // Calculate the output sizes with the static information.
3592  SmallVector<int64_t> newOutDims;
3593  for (size_t i = 0; i < inputRank; i++) {
3594  if (outputDims[i] == ShapedType::kDynamic) {
3595  newOutDims.push_back(
3596  (staticLow[i] == ShapedType::kDynamic ||
3597  staticHigh[i] == ShapedType::kDynamic ||
3598  inputDims[i] == ShapedType::kDynamic
3599  ? ShapedType::kDynamic
3600  : inputDims[i] + staticLow[i] + staticHigh[i]));
3601  } else {
3602  newOutDims.push_back(outputDims[i]);
3603  }
3604  }
3605 
3606  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3607  llvm::all_of(newOutDims,
3608  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3609  return failure();
3610 
3611  // Rewrite the op using the new static type.
3612  auto newResultType = RankedTensorType::get(
3613  newOutDims, padTensorOp.getType().getElementType());
3614  auto newOp = rewriter.create<PadOp>(
3615  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3616  newLows, newHighs, padTensorOp.getNofold(),
3617  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3618 
3619  IRMapping mapper;
3620  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3621  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3622  newOp);
3623 
3624  return success();
3625  }
3626 };
3627 
3628 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3629 ///
3630 /// Example:
3631 ///
3632 /// ```mlir
3633 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3634 /// tensor.yield %val
3635 /// } : tensor<1x2xf32> to tensor<1x5xf32>
3636 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3637 /// tensor.yield %val
3638 /// } : tensor<1x5xf32> to tensor<4x7xf32>
3639 /// ```
3640 ///
3641 /// folds into:
3642 ///
3643 /// ```mlir
3644 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3645 /// tensor.yield %val
3646 /// } : tensor<1x2xf32> to tensor<4x7xf32>
3647 /// ```
3648 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3649  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3650 
3651  LogicalResult matchAndRewrite(tensor::PadOp padOp,
3652  PatternRewriter &rewriter) const override {
3653  if (padOp.getNofold()) {
3654  return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3655  }
3656 
3657  auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3658  if (!producerPad || producerPad.getNofold()) {
3659  return rewriter.notifyMatchFailure(
3660  padOp, "producer is not a foldable tensor.pad op");
3661  }
3662 
3663  // Fail if the tensor::PadOps padding values do not match.
3664  Value consumerPadValue = padOp.getConstantPaddingValue();
3665  Value producerPadValue = producerPad.getConstantPaddingValue();
3666  if (!consumerPadValue || !producerPadValue ||
3667  consumerPadValue != producerPadValue) {
3668  return rewriter.notifyMatchFailure(
3669  padOp,
3670  "cannot fold PadOps with different or non-constant padding values");
3671  }
3672 
3673  Location loc = padOp.getLoc();
3674  AffineExpr d0, d1;
3675  bindDims(rewriter.getContext(), d0, d1);
3676 
3677  // Combine the low/high paddings of the two tensor::PadOps.
3678  auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3679  ArrayRef<OpFoldResult> producerPaddings) {
3680  SmallVector<OpFoldResult> sumPaddings;
3681  for (auto [consumerIndex, producerIndex] :
3682  llvm::zip_equal(consumerPaddings, producerPaddings)) {
3683  sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3684  rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3685  }
3686  return sumPaddings;
3687  };
3688 
3689  SmallVector<OpFoldResult> newHighPad =
3690  addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3691  SmallVector<OpFoldResult> newLowPad =
3692  addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3693 
3694  auto newPadOp = rewriter.create<tensor::PadOp>(
3695  padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3696  newLowPad, newHighPad, padOp.getNofold(),
3697  getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3698  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3699  newPadOp.getRegion().begin());
3700  rewriter.replaceOp(padOp, newPadOp.getResult());
3701  return success();
3702  }
3703 };
3704 
3705 } // namespace
3706 
3707 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3708  MLIRContext *context) {
3709  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3710  FoldOrthogonalPaddings, FoldStaticPadding,
3711  FoldConsecutiveConstantPadding>(context);
3712 }
3713 
3714 /// Return the padding value of the PadOp if it is constant. In this context,
3715 /// "constant" means an actual constant or "defined outside of the block".
3716 ///
3717 /// Values are considered constant in three cases:
3718 /// - A ConstantLike value.
3719 /// - A basic block argument from a different block.
3720 /// - A value defined outside of the block.
3721 ///
3722 /// If the padding value is not constant, an empty Value is returned.
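/// Illustrative sketch: a pad whose body is `{ tensor.yield %cst }`, where
/// %cst is produced by an arith.constant or is defined outside the pad's
/// block, returns %cst; a body that yields a value computed inside the pad's
/// block returns an empty Value.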
3723 Value PadOp::getConstantPaddingValue() {
3724  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3725  if (!yieldOp)
3726  return {};
3727  Value padValue = yieldOp.getValue();
3728  // Check if yield value is a constant.
3729  if (matchPattern(padValue, m_Constant()))
3730  return padValue;
3731  // Check if yield value is defined inside the PadOp block.
3732  if (padValue.getParentBlock() == &getRegion().front())
3733  return {};
3734  // Else: Yield value defined outside of the PadOp block.
3735  return padValue;
3736 }
3737 
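/// The fold below removes a pad that provably does nothing: the result type is
/// fully static, identical to the source type, and nofold is not set.
/// Illustrative sketch:
///
///   %p = tensor.pad %t low[0, 0] high[0, 0] {
///     tensor.yield %cst
///   } : tensor<4x4xf32> to tensor<4x4xf32>
///
/// folds to %t.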
3738 OpFoldResult PadOp::fold(FoldAdaptor) {
3739  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3740  !getNofold())
3741  return getSource();
3742  return {};
3743 }
3744 
3745 //===----------------------------------------------------------------------===//
3746 // ParallelInsertSliceOp
3747 //===----------------------------------------------------------------------===//
3748 
3749 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3750  ParallelCombiningOpInterface parallelCombiningParent =
3751  getParallelCombiningParent();
3752  for (const auto &it :
3753  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3754  Operation &nextOp = it.value();
3755  if (&nextOp == getOperation())
3756  return parallelCombiningParent.getParentResult(it.index());
3757  }
3758  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3759 }
3760 
3761 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3762 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3763  Value source, Value dest,
3764  ArrayRef<OpFoldResult> offsets,
3765  ArrayRef<OpFoldResult> sizes,
3766  ArrayRef<OpFoldResult> strides,
3767  ArrayRef<NamedAttribute> attrs) {
3768  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3769  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3770  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3771  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3772  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3773  result.addAttributes(attrs);
3774  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3775  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3776  b.getDenseI64ArrayAttr(staticSizes),
3777  b.getDenseI64ArrayAttr(staticStrides));
3778 }
3779 
3780 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3781 /// packed into a Range vector.
3782 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3783  Value source, Value dest,
3784  ArrayRef<Range> ranges,
3785  ArrayRef<NamedAttribute> attrs) {
3786  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3787  build(b, result, source, dest, offsets, sizes, strides, attrs);
3788 }
3789 
3790 // Build a ParallelInsertSliceOp with dynamic entries.
3791 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3792  Value source, Value dest, ValueRange offsets,
3793  ValueRange sizes, ValueRange strides,
3794  ArrayRef<NamedAttribute> attrs) {
3795  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3796  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3797  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3798  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3799  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3800  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3801  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3802 }
3803 
3804 LogicalResult ParallelInsertSliceOp::verify() {
3805  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3806  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3807  << *(getOperation()->getParentOp());
3808 
3809  // Verify result type against inferred type.
3810  RankedTensorType expectedType;
3811  SliceVerificationResult result =
3812  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3813  getStaticSizes(), getStaticStrides(), &expectedType);
3814  if (result != SliceVerificationResult::Success)
3815  return produceSliceErrorMsg(result, *this, expectedType);
3816 
3817  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3818  // to the destination tensor.
3819  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3820  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3821  getStaticStrides(), /*generateErrorMessage=*/true);
3822  if (!boundsResult.isValid)
3823  return getOperation()->emitError(boundsResult.errorMessage);
3824 
3825  return success();
3826 }
3827 
3828 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3829  RewritePatternSet &results, MLIRContext *context) {
3830  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3831  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3832  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3833 }
3834 
3835 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3836  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3837 }
3838 
3839 //===----------------------------------------------------------------------===//
3840 // ScatterOp
3841 //===----------------------------------------------------------------------===//
3842 
3843 void ScatterOp::getAsmResultNames(
3844  function_ref<void(Value, StringRef)> setNameFn) {
3845  setNameFn(getResult(), "scatter");
3846 }
3847 
3848 LogicalResult ScatterOp::verify() {
3849  int64_t destRank = getDestType().getRank();
3850  ArrayRef<int64_t> scatterDims = getScatterDims();
3851  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3852  getIndicesType().getShape(), destRank,
3853  "scatter", "dest")))
3854  return failure();
3855 
3856  if (!getUnique())
3857  return emitOpError("requires 'unique' attribute to be set");
3858  // TODO: we could also check statically that there are fewer leading index
3859  // tensor dims than the dest dims. If this is not the case, the unique
3860  // attribute cannot be true.
3861 
3862  // Use the GatherOp::inferResultType on the `dest` type and verify the
3863  // expected type matches the source type.
3864  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3865  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3866  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3867  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3868  if (getSourceType() != expectedSourceType &&
3869  getSourceType() != expectedRankReducedSourceType) {
3870  return emitOpError("source type "
3871  "mismatch: "
3872  "expected ")
3873  << expectedSourceType << " or its rank-reduced variant "
3874  << expectedRankReducedSourceType << " (got: " << getSourceType()
3875  << ")";
3876  }
3877 
3878  return success();
3879 }
3880 
3881 //===----------------------------------------------------------------------===//
3882 // SplatOp
3883 //===----------------------------------------------------------------------===//
3884 
3885 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3886  Type aggregateType, ValueRange dynamicSizes) {
3887  build(builder, result, aggregateType, element, dynamicSizes);
3888 }
3889 
3890 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3891  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3892  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3893  build(builder, result, aggregateType, element, dynamicSizes);
3894 }
3895 
3896 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3897  ArrayRef<OpFoldResult> sizes) {
3898  SmallVector<int64_t> staticShape;
3899  SmallVector<Value> dynamicSizes;
3900  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3901  build(builder, result, element, staticShape, dynamicSizes);
3902 }
3903 
3904 void SplatOp::getAsmResultNames(
3905  function_ref<void(Value, StringRef)> setNameFn) {
3906  setNameFn(getResult(), "splat");
3907 }
3908 
3909 LogicalResult SplatOp::verify() {
3910  if (getType().getNumDynamicDims() != getDynamicSizes().size())
3911  return emitOpError("incorrect number of dynamic sizes, has ")
3912  << getDynamicSizes().size() << ", expected "
3913  << getType().getNumDynamicDims();
3914  return success();
3915 }
3916 
3917 LogicalResult
3918 SplatOp::reifyResultShapes(OpBuilder &builder,
3919  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3920  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3921  unsigned ctr = 0;
3922  for (int64_t i = 0; i < getType().getRank(); ++i) {
3923  if (getType().isDynamicDim(i)) {
3924  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3925  } else {
3926  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3927  }
3928  }
3929  return success();
3930 }
3931 
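/// Illustrative sketch of the constant fold below: with
/// %c = arith.constant 1.0 : f32,
///
///   %s = tensor.splat %c : tensor<2x3xf32>
///
/// folds to the splat attribute dense<1.0> : tensor<2x3xf32>, while splats
/// with a dynamically shaped result type are left untouched.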
3932 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3933  auto constOperand = adaptor.getInput();
3934  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3935  return {};
3936 
3937  // Do not fold if the splat is not statically shaped
3938  if (!getType().hasStaticShape())
3939  return {};
3940 
3941  // SplatElementsAttr::get treats a single value for the second arg as a
3942  // splat.
3943  return SplatElementsAttr::get(getType(), {constOperand});
3944 }
3945 
3946 //===----------------------------------------------------------------------===//
3947 // Common Canonicalizers and Folders.
3948 //===----------------------------------------------------------------------===//
3949 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
3950  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
3951  // 2. Exclude DPS ops that are also LoopLike from this interface as they
3952  // might need special handling of attached regions.
3953  if (isa<InsertSliceOp>(op.getOperation()) ||
3954  isa<LoopLikeOpInterface>(op.getOperation()))
3955  return false;
3956 
3957  return hasFoldableTensorCastOperand(op);
3958 }
3959 
3960 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
3961 /// the `tensor.cast` has a source that is more static than the consuming op.
3962 ///
3963 /// Example:
3964 /// ```mlir
3965 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3966 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
3967 /// ```
3968 ///
3969 /// folds into:
3970 ///
3971 /// ```mlir
3972 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
3973 /// ```
3974 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOps
3975 /// can add the pattern to their canonicalizers.
3976 struct FoldTensorCastProducerOp
3977  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
3978  using OpInterfaceRewritePattern<
3979  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
3980 
3981  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
3982  PatternRewriter &rewriter) const override {
3983 
3984  // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
3985  // for that instead.
3986  if (!foldTensorCastPrecondition(op) ||
3987  isa<linalg::RelayoutOpInterface>(*op))
3988  return failure();
3989 
3990  SmallVector<Type> newResultTypes(op->getResultTypes());
3991  SmallVector<Value> newOperands =
3992  getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
3993 
3994  // Clone op
3995  auto newOp = clone(rewriter, op, newResultTypes, newOperands);
3996 
3997  SmallVector<Value, 4> replacements;
3998  replacements.reserve(newOp->getNumResults());
3999  for (auto [oldResult, newResult] :
4000  llvm::zip(op->getResults(), newOp->getResults())) {
4001  if (newResult.getType() != oldResult.getType()) {
4002  replacements.push_back(rewriter.create<tensor::CastOp>(
4003  op->getLoc(), oldResult.getType(), newResult));
4004  } else {
4005  replacements.push_back(newResult);
4006  }
4007  }
4008  rewriter.replaceOp(op, replacements);
4009 
4010  return success();
4011  }
4012 };
4013 
4014 //===----------------------------------------------------------------------===//
4015 // TensorDialect
4016 //===----------------------------------------------------------------------===//
4017 
4018 void TensorDialect::getCanonicalizationPatterns(
4019  RewritePatternSet &results) const {
4020  results.add<FoldTensorCastProducerOp>(getContext());
4021 }
4022 
4023 //===----------------------------------------------------------------------===//
4024 // TableGen'd op method definitions
4025 //===----------------------------------------------------------------------===//
4026 
4027 #define GET_OP_CLASSES
4028 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition: TensorOps.cpp:420
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition: TensorOps.cpp:1471
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
Definition: TensorOps.cpp:2350
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:3120
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition: TensorOps.cpp:2802
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition: TensorOps.cpp:2642
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition: TensorOps.cpp:2664
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1729
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition: TensorOps.cpp:2751
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
Definition: TensorOps.cpp:184
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:3116
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:140
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition: TensorOps.cpp:2822
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition: TensorOps.cpp:1956
bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
Definition: TensorOps.cpp:3949
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:364
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:51
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:409
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
This is a value defined by a result of an operation.
Definition: Value.h:433
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:445
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:214
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:225
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1224
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:391
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:360
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2602
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:353
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:369
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:322
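A short sketch of the intended use (assumed context, not a pattern from this file): when a tensor.cast only erases static shape information, a consumer that can accept the more static operand type may bypass the cast and read its source directly.
// `operand` is a hypothetical Value feeding some consumer op.
Value source = operand;
if (auto castOp = operand.getDefiningOp<tensor::CastOp>())
  if (tensor::canFoldIntoConsumerOp(castOp))
    source = castOp.getSource(); // at least as static as the cast result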
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes.
Definition: TensorOps.cpp:3093
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes.
Definition: TensorOps.cpp:2689
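A sketch of the two rank-reducing slice helpers under assumed types: extract a tensor<16xf32> from a value source of type tensor<1x16x1xf32>, then insert it back; b and loc are an in-scope builder and location.
auto reducedType = RankedTensorType::get({16}, b.getF32Type());
Value reduced = tensor::createCanonicalRankReducingExtractSliceOp(
    b, loc, source, reducedType);
Value restored =
    tensor::createCanonicalRankReducingInsertSliceOp(b, loc, reduced, source);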
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:128
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:61
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:79
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:167
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source type.
Definition: TensorOps.cpp:270
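An illustrative check under assumed types: a target that turns a static dimension into a dynamic one does not preserve the source's static information, while the reverse direction does; b is an in-scope builder.
Type f32 = b.getF32Type();
auto staticTy = RankedTensorType::get({8, 16}, f32);
auto dynamicTy = RankedTensorType::get({ShapedType::kDynamic, 16}, f32);
// false: the target drops the static size 8 of dim 0.
bool toDynamic = tensor::preservesStaticInformation(/*source=*/staticTy,
                                                    /*target=*/dynamicTy);
// true: the target only adds static information.
bool toStatic = tensor::preservesStaticInformation(/*source=*/dynamicTy,
                                                   /*target=*/staticTy);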
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
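A small sketch of a common use of this overload: build a tensor.empty with the same (possibly dynamic) shape as an existing value. The helper name buildEmptyLike is an assumption.
static Value buildEmptyLike(OpBuilder &b, Location loc, Value src) {
  auto srcType = cast<RankedTensorType>(src.getType());
  // Static dims come back as index attributes, dynamic dims as tensor.dim.
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, src);
  return b.create<tensor::EmptyOp>(loc, sizes, srcType.getElementType());
}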
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
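A hedged sketch of how the destination helpers are typically invoked, e.g. when collecting init operands during tiling: one destination value per tensor result of op, reusing existing DPS inits or creating tensor.empty ops.
static LogicalResult collectDestinations(OpBuilder &b, Operation *op,
                                         SmallVector<Value> &dests) {
  return tensor::getOrCreateDestinations(b, op->getLoc(), op, dests);
}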
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
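A minimal sketch of matchPattern combined with m_ConstantInt, recovering the value of a constant integer operand; operand is a hypothetical Value.
APInt constValue;
if (matchPattern(operand, m_ConstantInt(&constValue))) {
  // The operand is produced by a constant; its value is now available
  // statically, e.g. as constValue.getSExtValue().
}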
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) formed by separately extracting these values.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
Definition: BuiltinTypes.h:340
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
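A short sketch: an OpFoldResult wraps either an Attribute or a Value, and getConstantIntValue recovers the integer from both forms when it is constant; an in-scope builder b is assumed.
OpFoldResult ofr = b.getIndexAttr(4);
if (std::optional<int64_t> cst = getConstantIntValue(ofr))
  assert(*cst == 4 && "attribute form folds to its integer value");
// A Value-backed OpFoldResult only yields a result if its defining op is a
// foldable constant.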
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that would prevent erasing.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:91
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true will be replaced by dynamicValues.
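A sketch of the static/dynamic round trip these helpers implement: merge a static shape and its dynamic size values into OpFoldResults, then split them back. dynSizes is a hypothetical ValueRange holding one index value per dynamic entry, and b is an in-scope builder.
SmallVector<OpFoldResult> mixed = getMixedValues(
    /*staticValues=*/{4, ShapedType::kDynamic}, /*dynamicValues=*/dynSizes,
    b.getContext());
// Inverse direction: the static entries plus the remaining dynamic Values.
auto [staticSizes, dynamicSizes] = decomposeMixedValues(mixed);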
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices that specifies which of the entries of originalShape are dropped.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
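A sketch of the rank-reduction query under assumed shapes: reducing tensor<1x8x1x4xf32> to tensor<8x4xf32> drops the unit dimensions at positions 0 and 2.
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask(/*originalShape=*/{1, 8, 1, 4},
                             /*reducedShape=*/{8, 4});
// mask contains {0, 2}; std::nullopt would mean the reduced shape cannot be
// obtained by dropping unit dimensions.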
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
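A brief sketch: entries of mixedOffsets (a hypothetical SmallVector<OpFoldResult>) that are SSA values but fold to constants are rewritten in place to index attributes, so later canonicalization sees the static form.
bool changed = succeeded(foldDynamicOffsetSizeList(mixedOffsets));
// `changed` is true iff at least one Value entry became an index attribute;
// callers typically rebuild the op with the more static operand list.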
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor.cast has source that is more static than the consuming op.
Definition: TensorOps.cpp:3977
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:3981
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2621
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2622
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2609
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2610
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
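A closing sketch tying verifyInBoundsSlice to this result struct, under assumed numbers: reading 6 elements starting at offset 4 from a dimension of size 8 runs past the end, and the generated message can be forwarded to a verifier diagnostic (op is a hypothetical Operation*).
SliceBoundsVerificationResult result =
    verifyInBoundsSlice(/*shape=*/{8}, /*staticOffsets=*/{4},
                        /*staticSizes=*/{6}, /*staticStrides=*/{1},
                        /*generateErrorMessage=*/true);
if (!result.isValid)
  op->emitError(result.errorMessage);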