TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/Support/LLVM.h"
33 #include "llvm/ADT/DenseSet.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallBitVector.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/MathExtras.h"
39 #include <optional>
40 
41 using namespace mlir;
42 using namespace mlir::tensor;
43 
44 /// Materialize a single constant operation from a given attribute value with
45 /// the desired resultant type.
46 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
47  Attribute value, Type type,
48  Location loc) {
49  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
50  return op;
51  if (complex::ConstantOp::isBuildableWith(value, type))
52  return complex::ConstantOp::create(builder, loc, type,
53  llvm::cast<ArrayAttr>(value));
54  return nullptr;
55 }
56 
57 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
58  int64_t dim) {
59  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
60  if (tensorType.isDynamicDim(dim))
61  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
62 
63  return builder.getIndexAttr(tensorType.getDimSize(dim));
64 }
65 
66 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
67  Location loc, Value value) {
68  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
69  SmallVector<OpFoldResult> result;
70  for (int64_t i = 0; i < tensorType.getRank(); ++i)
71  result.push_back(getMixedSize(builder, loc, value, i));
72  return result;
73 }
74 
75 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
76  OpResult opResult) {
77  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
78  assert(tensorType && "expected tensor type");
79 
80  // If the op has a destination, it implements DestinationStyleOpInterface and
81  // we can query the destination operand from that interface.
82  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
83  if (destOp)
84  return destOp.getTiedOpOperand(opResult)->get();
85 
86  // Otherwise, create a new destination tensor with the same shape.
87  OpBuilder::InsertionGuard g(b);
88  b.setInsertionPoint(opResult.getDefiningOp());
89 
90  // Compute sizes.
91  SmallVector<OpFoldResult> mixedSizes;
92  if (!tensorType.hasStaticShape()) {
93  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
94  ReifiedRankedShapedTypeDims reifiedShapes;
95  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
96  return failure();
97  mixedSizes = reifiedShapes[opResult.getResultNumber()];
98  } else {
99  // Static shape: Take static sizes directly.
100  for (int64_t sz : tensorType.getShape())
101  mixedSizes.push_back(b.getIndexAttr(sz));
102  }
103 
104  // Create empty tensor.
105  Value emptyTensor =
106  tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
107  return emptyTensor;
108 }
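For illustration (editor's addition, not part of the upstream source): given a result of type `tensor<?x16xf32>` whose producer is not destination-style, the helper reifies the dynamic extent (written here as a hypothetical `%d0`) and materializes a matching destination:
```mlir
%dest = tensor.empty(%d0) : tensor<?x16xf32>
```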
109 
110 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
111  Operation *op,
112  SmallVector<Value> &result) {
113  for (OpResult opResult : op->getResults()) {
114  if (llvm::isa<TensorType>(opResult.getType())) {
115  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
116  if (failed(destination))
117  return failure();
118  result.push_back(*destination);
119  }
120  }
121  return success();
122 }
123 
124 bool mlir::tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
125  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
126  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
127  return rtp1.getShape() == rtp2.getShape() &&
128  rtp1.getElementType() == rtp2.getElementType();
129  return false;
130  }
131  return tp1 == tp2; // default implementation
132 }
133 
134 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135 /// rank-extending tensor.insert_slice op.
136 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
137  ArrayRef<OpFoldResult> mixedSizes) {
138  llvm::SmallBitVector droppedDims(mixedSizes.size());
139  int64_t shapePos = reducedShape.size() - 1;
140 
141  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
142  size_t idx = mixedSizes.size() - size.index() - 1;
143  // Rank-reduced dims must have a static unit dimension.
144  bool isStaticUnitSize =
145  isa<Attribute>(size.value()) &&
146  llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
147 
148  if (shapePos < 0) {
149  // There are no more dims in the reduced shape. All remaining sizes must
150  // be rank-reduced dims.
151  assert(isStaticUnitSize && "expected unit dim");
152  droppedDims.set(idx);
153  continue;
154  }
155 
156  // Dim is preserved if the size is not a static 1.
157  if (!isStaticUnitSize) {
158  --shapePos;
159  continue;
160  }
161 
162  // Dim is preserved if the reduced shape dim is also 1.
163  if (reducedShape[shapePos] == 1) {
164  --shapePos;
165  continue;
166  }
167 
168  // Otherwise: Dim is dropped.
169  droppedDims.set(idx);
170  }
171 
172  assert(shapePos < 0 && "dimension mismatch");
173  return droppedDims;
174 }
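Illustrative example (added for clarity, not from the source): for the rank-reducing slice
```mlir
%s = tensor.extract_slice %t[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
    : tensor<2x8x3x4xf32> to tensor<8x4xf32>
```
`reducedShape` is `[8, 4]` and `mixedSizes` is `[1, 8, 1, 4]`, so the dropped dimensions are 0 and 2.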
175 
176 /// Given a ranked tensor type and a range of values that defines its dynamic
177 /// dimension sizes, turn all dynamic sizes that have a constant value into
178 /// static dimension sizes.
179 static RankedTensorType
180 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
181  SmallVector<Value> &foldedDynamicSizes) {
182  SmallVector<int64_t> staticShape(type.getShape());
183  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
184  "incorrect number of dynamic sizes");
185 
186  // Compute new static and dynamic sizes.
187  unsigned ctr = 0;
188  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189  if (type.isDynamicDim(i)) {
190  Value dynamicSize = dynamicSizes[ctr++];
191  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192  if (cst.has_value()) {
193  // Dynamic size must be non-negative.
194  if (cst.value() < 0) {
195  foldedDynamicSizes.push_back(dynamicSize);
196  continue;
197  }
198  staticShape[i] = *cst;
199  } else {
200  foldedDynamicSizes.push_back(dynamicSize);
201  }
202  }
203  }
204 
205  return RankedTensorType::get(staticShape, type.getElementType(),
206  type.getEncoding());
207 }
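As a sketch (not in the original source): for `tensor<?x?xf32>` with dynamic sizes `(%d, %c5)` where `%c5` is a constant 5, the folded type is `tensor<?x5xf32>` and only `%d` remains dynamic:
```mlir
%c5 = arith.constant 5 : index
%0 = tensor.empty(%d, %c5) : tensor<?x?xf32>
// foldable to: %0 = tensor.empty(%d) : tensor<?x5xf32>
```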
208 
209 //===----------------------------------------------------------------------===//
210 // BitcastOp
211 //===----------------------------------------------------------------------===//
212 
213 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214  if (inputs.size() != 1 || outputs.size() != 1)
215  return false;
216  Type a = inputs.front(), b = outputs.front();
217  auto aT = dyn_cast<TensorType>(a);
218  auto bT = dyn_cast<TensorType>(b);
219  if (!aT || !bT)
220  return false;
221 
222  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223  return false;
224 
225  return succeeded(verifyCompatibleShape(aT, bT));
226 }
227 
228 namespace {
229 
230 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231 /// operation.
232 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233  using OpRewritePattern<BitcastOp>::OpRewritePattern;
234 
235  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236  PatternRewriter &rewriter) const final {
237  auto tensorBitcastOperand =
238  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239  if (!tensorBitcastOperand)
240  return failure();
241 
242  auto resultType = cast<TensorType>(tensorBitcast.getType());
243  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244  tensorBitcastOperand.getOperand());
245  return success();
246  }
247 };
248 
249 } // namespace
250 
251 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252  MLIRContext *context) {
253  results.add<ChainedTensorBitcast>(context);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // CastOp
258 //===----------------------------------------------------------------------===//
259 
260 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261  setNameFn(getResult(), "cast");
262 }
263 
264 /// Returns true if `target` is a ranked tensor type that preserves static
265 /// information available in the `source` ranked tensor type.
266 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
267  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
269 
270  // Requires RankedTensorType.
271  if (!sourceType || !targetType)
272  return false;
273 
274  // Requires same elemental type.
275  if (sourceType.getElementType() != targetType.getElementType())
276  return false;
277 
278  // Requires same rank.
279  if (sourceType.getRank() != targetType.getRank())
280  return false;
281 
282  // Requires same encoding.
283  if (sourceType.getEncoding() != targetType.getEncoding())
284  return false;
285 
286  // If cast is towards more static sizes along any dimension, don't fold.
287  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288  if (ShapedType::isStatic(std::get<0>(t)) &&
289  ShapedType::isDynamic(std::get<1>(t)))
290  return false;
291  }
292 
293  return true;
294 }
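For illustration (editor's addition): the target must be at least as static as the source in every dimension, e.g.
```mlir
// preservesStaticInformation(tensor<8x?xf32>,  tensor<8x16xf32>) -> true
// preservesStaticInformation(tensor<8x16xf32>, tensor<8x?xf32>)  -> false
```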
295 
296 /// Determines whether tensor::CastOp casts to a more dynamic version of the
297 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
298 /// implement canonicalization patterns for ops in different dialects that may
299 /// consume the results of tensor.cast operations. Such foldable tensor.cast
300 /// operations are typically inserted as `slice` ops and are canonicalized,
301 /// to preserve the type compatibility of their uses.
302 ///
303 /// Returns true when all conditions are met:
304 /// 1. source and result are ranked tensors with same element type and rank.
305 /// 2. the source type has more static information than the result.
306 ///
307 /// Example:
308 /// ```mlir
309 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
311 /// ```
312 ///
313 /// folds into:
314 ///
315 /// ```mlir
316 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
317 /// ```
318 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
319  if (!castOp)
320  return false;
321 
322  // Can fold if the source of cast has at least as much static information as
323  // its results.
324  return preservesStaticInformation(castOp.getType(),
325  castOp.getSource().getType());
326 }
327 
328 /// Determines whether the tensor::CastOp casts to a more static version of the
329 /// source tensor. This is useful to fold into a producing op and implement
330 /// canonicalization patterns with the `tensor.cast` op as the root, but
331 /// producer being from different dialects. Returns true when all conditions are
332 /// met:
333 /// 1. source and result are ranked tensors with same element type and rank.
334 /// 2. the result type has more static information than the source.
335 ///
336 /// Example:
337 /// ```mlir
338 /// %1 = producer ... : tensor<?x?xf32>
339 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
340 /// ```
341 ///
342 /// can be canonicalized to :
343 ///
344 /// ```mlir
345 /// %2 = producer ... : tensor<8x16xf32>
346 /// ```
347 /// Not all ops might be canonicalizable this way, but for those that can be,
348 /// this method provides a check that it is worth doing the canonicalization.
349 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
350  if (!castOp)
351  return false;
352  return preservesStaticInformation(castOp.getSource().getType(),
353  castOp.getType());
354 }
355 
356 bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
357  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
358  if (llvm::isa<BlockArgument>(opOperand.get()))
359  return false;
360  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
361  return castOp && canFoldIntoConsumerOp(castOp);
362  });
363 }
364 
365 SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
366  DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
367  SmallVector<Value> newOperands;
368  newOperands.reserve(op->getNumOperands());
369 
370  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
371 
372  // Assumes that the result has dpsInits followed by nonDpsInits.
373  int64_t dpsInitIdx = 0;
374  for (OpOperand &opOperand : op->getOpOperands()) {
375  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
376  bool fold = canFoldIntoConsumerOp(tensorCastOp);
377  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
378  if (op.isDpsInit(&opOperand) &&
379  !llvm::isa<MemRefType>(newOperands.back().getType()))
380  newResTy[dpsInitIdx++] = newOperands.back().getType();
381  }
382  return newOperands;
383 }
384 
385 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
386 /// that can be folded.
387 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
388  bool folded = false;
389  for (OpOperand &operand : op->getOpOperands()) {
390  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
391  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
392  operand.set(castOp.getOperand());
393  folded = true;
394  }
395  }
396  return success(folded);
397 }
398 
399 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
400  if (inputs.size() != 1 || outputs.size() != 1)
401  return false;
402  Type a = inputs.front(), b = outputs.front();
403  auto aT = llvm::dyn_cast<TensorType>(a);
404  auto bT = llvm::dyn_cast<TensorType>(b);
405  if (!aT || !bT)
406  return false;
407 
408  if (aT.getElementType() != bT.getElementType())
409  return false;
410 
411  return succeeded(verifyCompatibleShape(aT, bT));
412 }
413 
414 /// Compute a TensorType that has the joined shape knowledge of the two
415 /// given TensorTypes. The element types need to match.
416 static TensorType joinShapes(TensorType one, TensorType two) {
417  assert(one.getElementType() == two.getElementType());
418 
419  if (!one.hasRank())
420  return two;
421  if (!two.hasRank())
422  return one;
423 
424  int64_t rank = one.getRank();
425  if (rank != two.getRank())
426  return {};
427 
428  SmallVector<int64_t, 4> join;
429  join.reserve(rank);
430  for (int64_t i = 0; i < rank; ++i) {
431  if (one.isDynamicDim(i)) {
432  join.push_back(two.getDimSize(i));
433  continue;
434  }
435  if (two.isDynamicDim(i)) {
436  join.push_back(one.getDimSize(i));
437  continue;
438  }
439  if (one.getDimSize(i) != two.getDimSize(i))
440  return {};
441  join.push_back(one.getDimSize(i));
442  }
443  return RankedTensorType::get(join, one.getElementType());
444 }
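A small illustration of the join (not from the source):
```mlir
// joinShapes(tensor<4x?xf32>, tensor<?x8xf32>) -> tensor<4x8xf32>
// joinShapes(tensor<4x8xf32>, tensor<5x8xf32>) -> null (conflicting static sizes)
```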
445 
446 namespace {
447 
448 /// Replaces chains of two tensor.cast operations by a single tensor.cast
449 /// operation if doing so does not remove runtime constraints.
450 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
451  using OpRewritePattern<CastOp>::OpRewritePattern;
452 
453  LogicalResult matchAndRewrite(CastOp tensorCast,
454  PatternRewriter &rewriter) const final {
455  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
456 
457  if (!tensorCastOperand)
458  return failure();
459 
460  auto sourceType =
461  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
462  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
463  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
464 
465  // We can remove the intermediate cast if joining all three produces the
466  // same result as just joining the source and result shapes.
467  auto firstJoin =
468  joinShapes(joinShapes(sourceType, intermediateType), resultType);
469 
470  // The join might not exist if the cast sequence would fail at runtime.
471  if (!firstJoin)
472  return failure();
473 
474  // The newJoin always exists if the above join exists, it might just contain
475  // less information. If so, we cannot drop the intermediate cast, as doing
476  // so would remove runtime checks.
477  auto newJoin = joinShapes(sourceType, resultType);
478  if (firstJoin != newJoin)
479  return failure();
480 
481  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
482  tensorCastOperand.getOperand());
483  return success();
484  }
485 };
486 
487 /// Fold tensor.cast into tensor.extract_slice producer.
488 /// Example:
489 /// ```
490 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
491 /// tensor<128x512xf32> to tensor<?x512xf32>
492 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
493 /// ```
494 /// ->
495 /// ```
496 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
497 /// tensor<128x512xf32> to tensor<16x512xf32>
498 /// ```
499 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
500  using OpRewritePattern<CastOp>::OpRewritePattern;
501 
502  LogicalResult matchAndRewrite(CastOp tensorCast,
503  PatternRewriter &rewriter) const final {
504  auto extractOperand =
505  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
506 
507  // Cannot fold cast to unranked tensor.
508  auto rankedResultType =
509  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
510  if (!rankedResultType)
511  return failure();
512 
513  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
514  rankedResultType.getShape() ==
515  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
516  .getShape())
517  return failure();
518 
519  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
520  auto dimMask = computeRankReductionMask(
521  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
522  size_t dimIndex = 0;
523  for (size_t i = 0, e = sizes.size(); i < e; i++) {
524  if (dimMask && dimMask->count(i))
525  continue;
526  int64_t dim = rankedResultType.getShape()[dimIndex++];
527  if (ShapedType::isDynamic(dim))
528  continue;
529  sizes[i] = rewriter.getIndexAttr(dim);
530  }
531 
532  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
533  tensorCast, rankedResultType, extractOperand.getSource(),
534  extractOperand.getMixedOffsets(), sizes,
535  extractOperand.getMixedStrides());
536  return success();
537  }
538 };
539 
540 } // namespace
541 
542 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
543  MLIRContext *context) {
544  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // ConcatOp
549 //===----------------------------------------------------------------------===//
550 
551 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
552  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
553  auto tensorTypes =
554  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
555  return llvm::cast<RankedTensorType>(type);
556  }));
557  int64_t concatRank = tensorTypes[0].getRank();
558 
559  // The concatenation dim must be in the range [0, rank).
560  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
561 
562  SmallVector<int64_t> sizes(concatRank);
563  for (int64_t i = 0, e = concatRank; i < e; ++i) {
564  if (i == dim)
565  continue;
566  SaturatedInteger size;
567  for (auto tensorType : tensorTypes)
568  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
569  sizes[i] = size.asInteger();
570  }
571  auto concatSize = SaturatedInteger::wrap(0);
572  for (auto tensorType : tensorTypes)
573  concatSize =
574  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
575  sizes[dim] = concatSize.asInteger();
576  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
577 }
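Illustrative inference (editor's addition): non-concatenated dimensions are joined and the concatenated dimension is summed, so
```mlir
%c = tensor.concat dim(0) %a, %b
    : (tensor<3x?xf32>, tensor<5x4xf32>) -> tensor<8x4xf32>
```
If any input is dynamic along the concatenation dimension, the inferred size there is dynamic as well.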
578 
579 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
580  ValueRange inputs) {
581  FailureOr<RankedTensorType> resultType =
582  inferResultType(dim, inputs.getTypes());
583  assert(succeeded(resultType) && "failed to infer concatenation result type");
584  build(builder, result, *resultType, dim, inputs);
585 }
586 
587 LogicalResult ConcatOp::verify() {
588  if (getInputs().size() < 1)
589  return emitOpError("requires at least one input");
590 
591  SmallVector<RankedTensorType> inputTypes;
592  for (auto input : getInputs())
593  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
594 
595  RankedTensorType resultType = getResultType();
596  int64_t resultRank = getRank();
597  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
598  return type.getRank() != resultRank;
599  }))
600  return emitOpError("rank of concatenated inputs must match result rank");
601 
602  Type resultElementType = resultType.getElementType();
603  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
604  return type.getElementType() != resultElementType;
605  }))
606  return emitOpError("inputs and result element type must match");
607 
608  int64_t dim = getDim();
609  if (dim >= resultRank)
610  return emitOpError("concatenation dim must be less than the tensor rank");
611 
612  SmallVector<int64_t> sizes(resultRank);
613  for (int64_t i = 0, e = resultRank; i < e; ++i) {
614  if (i == dim)
615  continue;
616  SaturatedInteger size;
617  for (auto tensorType : inputTypes) {
618  FailureOr<SaturatedInteger> maybeSize =
619  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
620  if (failed(maybeSize))
621  return emitOpError("static concatenation size mismatch along ")
622  << "non-concatenated dimension " << i;
623  size = *maybeSize;
624  }
625  sizes[i] = size.asInteger();
626  }
627  auto concatSize = SaturatedInteger::wrap(0);
628  for (auto tensorType : inputTypes)
629  concatSize =
630  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
631  sizes[dim] = concatSize.asInteger();
632  auto inferredResultType =
633  RankedTensorType::get(sizes, inputTypes[0].getElementType());
634 
635  for (auto [inferredSize, actualSize] :
636  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
637  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
638  ShapedType::isDynamic(actualSize);
639  if (!hasDynamic && inferredSize != actualSize)
640  return emitOpError("result type ")
641  << resultType << " does not match inferred shape "
642  << inferredResultType << " static sizes";
643  }
644 
645  return success();
646 }
647 
648 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
649  size_t numInputs = getInputs().size();
650  uint64_t concatDim = getDim();
651 
652  SmallVector<SmallVector<OpFoldResult>> inputShapes;
653  inputShapes.reserve(numInputs);
654  SmallVector<OpFoldResult> concatOffsets;
655  concatOffsets.reserve(numInputs);
656  SmallVector<OpFoldResult> outputShape;
657 
658  AffineExpr addExpr =
659  builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
660  OpFoldResult zero = builder.getIndexAttr(0);
661  Location loc = getLoc();
662  for (auto [index, input] : llvm::enumerate(getInputs())) {
663  SmallVector<OpFoldResult> inputShape =
664  tensor::getMixedSizes(builder, input.getLoc(), input);
665  if (index == 0) {
666  outputShape = inputShape;
667  concatOffsets.push_back(zero);
668  } else {
669  concatOffsets.push_back(outputShape[concatDim]);
670  outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
671  builder, loc, addExpr,
672  {outputShape[concatDim], inputShape[concatDim]});
673  }
674  inputShapes.emplace_back(std::move(inputShape));
675  }
676 
677  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
678  getType().getElementType());
679 
680  int64_t rank = getType().getRank();
681  OpFoldResult one = builder.getIndexAttr(1);
682  SmallVector<OpFoldResult> strides(rank, one);
683  SmallVector<OpFoldResult> offsets(rank, zero);
684  for (auto [index, input] : llvm::enumerate(getInputs())) {
685  offsets[concatDim] = concatOffsets[index];
686  auto insertSlice = tensor::InsertSliceOp::create(
687  builder, loc, input, replacement, offsets, inputShapes[index], strides);
688  replacement = insertSlice.getResult();
689  }
690  if (replacement.getType() != getType()) {
691  replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
692  }
693  return SmallVector<Value>{replacement};
694 }
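A sketch of the decomposition with static shapes, so the affine size/offset computations fold to constants (illustrative only):
```mlir
// %r = tensor.concat dim(1) %a, %b
//     : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
// decomposes roughly into:
%empty = tensor.empty() : tensor<2x7xf32>
%s0 = tensor.insert_slice %a into %empty[0, 0] [2, 3] [1, 1]
    : tensor<2x3xf32> into tensor<2x7xf32>
%s1 = tensor.insert_slice %b into %s0[0, 3] [2, 4] [1, 1]
    : tensor<2x4xf32> into tensor<2x7xf32>
```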
695 
696 LogicalResult
697 ConcatOp::reifyResultShapes(OpBuilder &builder,
698  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
699  ValueRange inputs = getInputs();
700  int64_t dim = getDim();
701  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
702 
703  Value init = inputs[0];
704  int64_t rank = getType().getRank();
705 
706  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
707 
708  // Pre-populate the result sizes with as much static information as possible
709  // from the given result type, as well as the inferred result type, otherwise
710  // use the dim sizes from the first input.
711  for (int64_t i = 0; i < rank; ++i) {
712  if (i == dim)
713  continue;
714  if (!getType().isDynamicDim(i)) {
715  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
716  } else if (!inferredResultType.isDynamicDim(i)) {
717  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
718  builder, getLoc(),
719  builder.getIndexAttr(inferredResultType.getDimSize(i)));
720  } else {
721  reifiedReturnShapes[0][i] =
722  tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
723  }
724  }
725 
726  if (getType().isDynamicDim(dim)) {
727  // Take the sum of the input sizes along the concatenated dim.
728  AffineExpr sum = builder.getAffineDimExpr(0);
729  SmallVector<OpFoldResult> sizes = {
730  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
731  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
732  sum = sum + builder.getAffineDimExpr(idx + 1);
733  sizes.push_back(
734  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
735  }
736  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
737  builder, getLoc(),
738  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
739  } else {
740  // If the result shape is static along the concatenated dim, use the static
741  // shape.
742  reifiedReturnShapes[0][dim] =
743  builder.getIndexAttr(getType().getDimSize(dim));
744  }
745  return success();
746 }
747 
748 void ConcatOp::getAsmResultNames(
749  function_ref<void(Value, StringRef)> setNameFn) {
750  setNameFn(getResult(), "concat");
751 }
752 
753 OpFoldResult ConcatOp::fold(FoldAdaptor) {
754  ValueRange inputs = getInputs();
755  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
756  return inputs[0];
757  return {};
758 }
759 
760 namespace {
761 /// Fold a concat op with a single input to a cast.
762 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
763  using OpRewritePattern<ConcatOp>::OpRewritePattern;
764 
765  LogicalResult matchAndRewrite(ConcatOp concatOp,
766  PatternRewriter &rewriter) const override {
767  if (concatOp.getInputs().size() != 1)
768  return failure();
769  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
770  concatOp.getInputs()[0]);
771  return success();
772  }
773 };
774 
775 /// Propagate static shapes into the operands of a `tensor.concat`.
776 ///
777 /// `tensor.concat` requires every operand to match on all dimensions except the
778 /// concatenation dimension. If one operand is already static in those
779 /// dimensions, the other operands may safely be refined to that same static
780 /// shape.
781 ///
782 /// Example:
783 ///
784 /// ```mlir
785 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
786 /// tensor<?x12xi32>
787 /// ```
788 /// ->
789 /// ```mlir
790 /// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
791 /// %2 = tensor.concat dim(0) %0, %cast :
792 /// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
793 /// ```
794 struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
795  using OpRewritePattern<ConcatOp>::OpRewritePattern;
796 
797  LogicalResult matchAndRewrite(ConcatOp concatOp,
798  PatternRewriter &rewriter) const override {
799  int64_t dim = concatOp.getDim();
800  RankedTensorType inferredResultType =
801  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
802 
803  // Find operands for which a more static shape can be inferred.
804  LogicalResult matched = failure();
805  // Inferred operand shapes are identical in every dimension except the
806  // concatenation dimension.
807  SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
808  for (auto [operandIdx, operandType] :
809  llvm::enumerate(concatOp->getOperandTypes())) {
810  // Compute inferred type for operand.
811  inferredOperandShape[dim] =
812  cast<RankedTensorType>(operandType).getDimSize(dim);
813  auto inferredOperandType = RankedTensorType::get(
814  inferredOperandShape, inferredResultType.getElementType());
815 
816  // Check if inferred type is more static.
817  if (!preservesStaticInformation(inferredOperandType, operandType)) {
818  matched = success();
819 
820  // Use refined operand type and create cast from original operand.
821  auto castOp =
822  CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
823  concatOp.getOperand(operandIdx));
824  rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
825  concatOp->setOperand(operandIdx, castOp->getResult(0));
826  });
827  }
828  }
829 
830  return matched;
831  }
832 };
833 
834 /// Ensure `tensor.concat`'s result type is at least as static as can be inferred
835 /// from its operand types.
836 ///
837 /// Example:
838 /// ```mlir
839 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
840 /// tensor<?x?xi32>
841 /// ```
842 /// ->
843 /// ```mlir
844 /// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
845 /// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
846 /// tensor<?x?xi32>
847 /// ```
848 struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
849  using OpRewritePattern<ConcatOp>::OpRewritePattern;
850 
851  LogicalResult matchAndRewrite(ConcatOp concatOp,
852  PatternRewriter &rewriter) const override {
853  int64_t dim = concatOp.getDim();
854  RankedTensorType inferredResultType =
855  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
856 
857  // The result type should be at least as static as inferred result type.
858  if (preservesStaticInformation(inferredResultType,
859  concatOp.getResultType())) {
860  return failure();
861  }
862 
863  auto newConcatOp =
864  ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
865  concatOp->getOperands());
866  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
867  newConcatOp);
868 
869  return success();
870  }
871 };
872 } // namespace
873 
874 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
875  MLIRContext *context) {
876  results
877  .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
878  context);
879 }
880 
881 //===----------------------------------------------------------------------===//
882 // DimOp
883 //===----------------------------------------------------------------------===//
884 
885 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
886  setNameFn(getResult(), "dim");
887 }
888 
889 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
890  int64_t index) {
891  auto loc = result.location;
892  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
893  build(builder, result, source, indexValue);
894 }
895 
896 std::optional<int64_t> DimOp::getConstantIndex() {
897  return getConstantIntValue(getIndex());
898 }
899 
900 Speculation::Speculatability DimOp::getSpeculatability() {
901  auto constantIndex = getConstantIndex();
902  if (!constantIndex)
903  return Speculation::NotSpeculatable;
904 
905  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
906  if (!rankedSourceType)
907  return Speculation::NotSpeculatable;
908 
909  if (rankedSourceType.getRank() <= constantIndex)
910  return Speculation::NotSpeculatable;
911 
912  return Speculation::Speculatable;
913 }
914 
915 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
916  SetIntLatticeFn setResultRange) {
917  setResultRange(getResult(),
918  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
919 }
920 
921 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
922  // All forms of folding require a known index.
923  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
924  if (!index)
925  return {};
926 
927  // Folding for unranked types (UnrankedTensorType) is not supported.
928  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
929  if (!tensorType)
930  return {};
931 
932  // Out of bound indices produce undefined behavior but are still valid IR.
933  // Don't choke on them.
934  int64_t indexVal = index.getInt();
935  if (indexVal < 0 || indexVal >= tensorType.getRank())
936  return {};
937 
938  // Fold if the shape extent along the given index is known.
939  if (!tensorType.isDynamicDim(index.getInt())) {
940  Builder builder(getContext());
941  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
942  }
943 
944  Operation *definingOp = getSource().getDefiningOp();
945 
946  // Fold dim to the operand of tensor.generate.
947  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
948  auto resultType =
949  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
950  // The case where the type encodes the size of the dimension is handled
951  // above.
952  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
953 
954  // Find the operand of the fromElements that corresponds to this index.
955  auto dynExtents = fromElements.getDynamicExtents().begin();
956  for (auto dim : resultType.getShape().take_front(index.getInt()))
957  if (ShapedType::isDynamic(dim))
958  dynExtents++;
959 
960  return Value{*dynExtents};
961  }
962 
963  // The size at the given index is now known to be a dynamic size.
964  unsigned unsignedIndex = index.getValue().getZExtValue();
965 
966  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
967  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
968  // `resolve-shaped-type-result-dims` pass.
969  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
970  sliceOp.isDynamicSize(unsignedIndex)) {
971  return {sliceOp.getDynamicSize(unsignedIndex)};
972  }
973  }
974 
975  // dim(cast) -> dim
976  if (succeeded(foldTensorCast(*this)))
977  return getResult();
978 
979  return {};
980 }
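For illustration (not part of the source): a `tensor.dim` over a statically known extent folds to an index constant:
```mlir
%c1 = arith.constant 1 : index
%d = tensor.dim %t, %c1 : tensor<?x16xf32>   // folds to the constant 16
```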
981 
982 namespace {
983 /// Fold dim of a cast into the dim of the source of the tensor cast.
984 struct DimOfCastOp : public OpRewritePattern<DimOp> {
985  using OpRewritePattern<DimOp>::OpRewritePattern;
986 
987  LogicalResult matchAndRewrite(DimOp dimOp,
988  PatternRewriter &rewriter) const override {
989  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
990  if (!castOp)
991  return failure();
992  Value newSource = castOp.getOperand();
993  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
994  return success();
995  }
996 };
997 
998 /// Fold dim of a destination passing style op into the dim of the corresponding
999 /// init.
1000 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1001  using OpRewritePattern<DimOp>::OpRewritePattern;
1002 
1003  LogicalResult matchAndRewrite(DimOp dimOp,
1004  PatternRewriter &rewriter) const override {
1005  auto source = dimOp.getSource();
1006  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1007  if (!destOp)
1008  return failure();
1009 
1010  auto resultIndex = cast<OpResult>(source).getResultNumber();
1011  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1012 
1013  rewriter.modifyOpInPlace(
1014  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1015  return success();
1016  }
1017 };
1018 
1019 /// Fold dim of a tensor reshape operation to an extract into the reshape's shape
1020 /// operand.
1021 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1022  using OpRewritePattern<DimOp>::OpRewritePattern;
1023 
1024  LogicalResult matchAndRewrite(DimOp dim,
1025  PatternRewriter &rewriter) const override {
1026  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1027 
1028  if (!reshape)
1029  return failure();
1030 
1031  // Since tensors are immutable we don't need to worry about where to place
1032  // the extract call
1033  rewriter.setInsertionPointAfter(dim);
1034  Location loc = dim.getLoc();
1035  Value extract =
1036  ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1037  if (extract.getType() != dim.getType())
1038  extract =
1039  arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1040  rewriter.replaceOp(dim, extract);
1041  return success();
1042  }
1043 };
1044 } // namespace
1045 
1046 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1047  MLIRContext *context) {
1048  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // EmptyOp
1053 //===----------------------------------------------------------------------===//
1054 
1055 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1056  ArrayRef<int64_t> staticShape, Type elementType,
1057  Attribute encoding) {
1058  assert(none_of(staticShape, ShapedType::isDynamic) &&
1059  "expected only static sizes");
1060  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1061 }
1062 
1063 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1064  ArrayRef<int64_t> staticShape, Type elementType,
1065  ValueRange dynamicSizes, Attribute encoding) {
1066  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1067  build(builder, result, tensorType, dynamicSizes);
1068 }
1069 
1070 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1071  ArrayRef<OpFoldResult> sizes, Type elementType,
1072  Attribute encoding) {
1073  SmallVector<int64_t> staticShape;
1074  SmallVector<Value> dynamicSizes;
1075  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1076  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1077 }
1078 
1079 LogicalResult EmptyOp::verify() {
1080  if (getType().getNumDynamicDims() != getDynamicSizes().size())
1081  return emitOpError("incorrect number of dynamic sizes, has ")
1082  << getDynamicSizes().size() << ", expected "
1083  << getType().getNumDynamicDims();
1084  return success();
1085 }
1086 
1087 LogicalResult
1088 EmptyOp::reifyResultShapes(OpBuilder &builder,
1089  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1090  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1091  unsigned ctr = 0;
1092  for (int64_t i = 0; i < getType().getRank(); ++i) {
1093  if (getType().isDynamicDim(i)) {
1094  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1095  } else {
1096  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1097  }
1098  }
1099  return success();
1100 }
1101 
1102 Value EmptyOp::getDynamicSize(unsigned idx) {
1103  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1104  unsigned ctr = 0;
1105  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1106  if (getType().isDynamicDim(i))
1107  ++ctr;
1108  return getDynamicSizes()[ctr];
1109 }
1110 
1111 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1112  SmallVector<OpFoldResult> result;
1113  unsigned ctr = 0;
1114  OpBuilder b(getContext());
1115  for (int64_t i = 0; i < getType().getRank(); ++i) {
1116  if (getType().isDynamicDim(i)) {
1117  result.push_back(getDynamicSizes()[ctr++]);
1118  } else {
1119  result.push_back(b.getIndexAttr(getType().getShape()[i]));
1120  }
1121  }
1122  return result;
1123 }
1124 
1125 namespace {
1126 /// Change the type of the result of a `tensor.empty` by making the result
1127 /// type statically sized along dimensions that in the original operation were
1128 /// defined as dynamic, but the size was defined using a `constant` op. For
1129 /// example
1130 ///
1131 /// %c5 = arith.constant 5: index
1132 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1133 ///
1134 /// to
1135 ///
1136 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1137 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1138  using OpRewritePattern<EmptyOp>::OpRewritePattern;
1139 
1140  LogicalResult matchAndRewrite(EmptyOp op,
1141  PatternRewriter &rewriter) const override {
1142  SmallVector<Value> foldedDynamicSizes;
1143  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1144  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1145 
1146  // Stop here if no dynamic size was promoted to static.
1147  if (foldedTensorType == op.getType())
1148  return failure();
1149 
1150  auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1151  foldedDynamicSizes);
1152  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1153  return success();
1154  }
1155 };
1156 
1157 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1158  using OpRewritePattern<DimOp>::OpRewritePattern;
1159 
1160  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1161  PatternRewriter &rewriter) const override {
1162  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1163  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1164  if (!emptyTensorOp || !maybeConstantIndex)
1165  return failure();
1166  auto emptyTensorType = emptyTensorOp.getType();
1167  if (*maybeConstantIndex < 0 ||
1168  *maybeConstantIndex >= emptyTensorType.getRank() ||
1169  !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1170  return failure();
1171  rewriter.replaceOp(dimOp,
1172  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1173  return success();
1174  }
1175 };
1176 
1177 /// Canonicalize
1178 ///
1179 /// ```mlir
1180 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1181 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1182 /// ```
1183 ///
1184 /// into
1185 ///
1186 /// ```mlir
1187 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1188 /// ```
1189 ///
1190 /// This assumes the input program is correct in terms of its shape. So it is
1191 /// safe to assume that `%d0` is in fact 4.
1192 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1193  using OpRewritePattern<CastOp>::OpRewritePattern;
1194 
1195  LogicalResult matchAndRewrite(CastOp castOp,
1196  PatternRewriter &rewriter) const override {
1197  if (!canFoldIntoProducerOp(castOp))
1198  return failure();
1199  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1200  if (!producer)
1201  return failure();
1202 
1203  auto resultType =
1204  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1205  ArrayRef<int64_t> resultShape = resultType.getShape();
1206  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1207  SmallVector<OpFoldResult> newMixedSizes;
1208  newMixedSizes.reserve(currMixedSizes.size());
1209  assert(resultShape.size() == currMixedSizes.size() &&
1210  "mismatch in result shape and sizes of empty op");
1211  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1212  int64_t newDim = std::get<0>(it);
1213  OpFoldResult currDim = std::get<1>(it);
1214  // Case 1: The empty tensor dim is static. Check that the tensor cast
1215  // result dim matches.
1216  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1217  if (ShapedType::isDynamic(newDim) ||
1218  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1219  // Something is off, the cast result shape cannot be more dynamic
1220  // than the empty tensor result shape (enforced by
1221  // `canFoldIntoProducer`). Abort for now.
1222  return rewriter.notifyMatchFailure(
1223  producer, "mismatch in static value of shape of empty tensor "
1224  "result and cast result");
1225  }
1226  newMixedSizes.push_back(attr);
1227  continue;
1228  }
1229 
1230  // Case 2 : The tensor cast shape is static, but empty tensor result
1231  // shape is dynamic.
1232  if (ShapedType::isStatic(newDim)) {
1233  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1234  continue;
1235  }
1236 
1237  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1238  // shape is dynamic. Use the dynamic value from the empty tensor op.
1239  newMixedSizes.push_back(currDim);
1240  }
1241 
1242  // TODO: Do not drop tensor encoding.
1243  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1244  resultType.getElementType());
1245  return success();
1246  }
1247 };
1248 
1249 } // namespace
1250 
1251 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1252  MLIRContext *context) {
1253  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1254  ReplaceEmptyTensorStaticShapeDims>(context);
1255 }
1256 
1257 //===----------------------------------------------------------------------===//
1258 // ExtractOp
1259 //===----------------------------------------------------------------------===//
1260 
1261 namespace {
1262 
1263 /// Canonicalizes the pattern of the form
1264 ///
1265 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1266 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1267 ///
1268 /// to
1269 ///
1270 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1271 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1272  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1273 
1274  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1275  PatternRewriter &rewriter) const final {
1276  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1277  if (!tensorCast)
1278  return failure();
1279  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1280  return failure();
1281  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1282  extract, tensorCast.getSource(), extract.getIndices());
1283  return success();
1284  }
1285 };
1286 
1287 /// Canonicalizes the pattern of the form
1288 ///
1289 /// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1290 /// tensor<12xf64>
1291 /// %extracted_element = tensor.extract %val[%c10] :
1292 /// tensor<12xf64>
1293 ///
1294 /// to
1295 ///
1296 /// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1297 struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1298  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1299 
1300  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1301  PatternRewriter &rewriter) const final {
1302  auto collapseOp =
1303  extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1304  if (!collapseOp)
1305  return failure();
1306  if (!collapseOp.getSrcType().hasStaticShape())
1307  return failure();
1308 
1309  auto sourceSizes = collapseOp.getSrcType().getShape();
1310 
1311  SmallVector<Value> indices(extractOp.getIndices().begin(),
1312  extractOp.getIndices().end());
1313  SmallVector<Value> sourceIndices;
1314  for (auto [index, group] :
1315  llvm::zip(indices, collapseOp.getReassociationIndices())) {
1316  assert(!group.empty() && "association indices groups cannot be empty");
1317  auto groupSize = group.size();
1318 
1319  if (groupSize == 1) {
1320  sourceIndices.push_back(index);
1321  continue;
1322  }
1323 
1324  SmallVector<int64_t> basis =
1325  llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1326  auto delinearize = affine::AffineDelinearizeIndexOp::create(
1327  rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1328  llvm::append_range(sourceIndices, delinearize.getResults());
1329  }
1330  if (collapseOp.getReassociationIndices().empty()) {
1331  auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1332  int64_t srcRank =
1333  cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1334  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1335  rewriter, extractOp.getLoc(), zeroAffineMap,
1336  ArrayRef<OpFoldResult>{});
1337  for (int64_t i = 0; i < srcRank; i++) {
1338  sourceIndices.push_back(
1339  getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1340  }
1341  }
1342 
1343  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1344  extractOp, collapseOp.getSrc(), sourceIndices);
1345  return success();
1346  }
1347 };
1348 
1349 } // namespace
1350 
1351 void ExtractOp::getAsmResultNames(
1352  function_ref<void(Value, StringRef)> setNameFn) {
1353  setNameFn(getResult(), "extracted");
1354 }
1355 
1356 LogicalResult ExtractOp::verify() {
1357  // Verify the # indices match if we have a ranked type.
1358  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1359  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1360  return emitOpError("incorrect number of indices for extract_element");
1361  return success();
1362 }
1363 
1364 /// If we have an ExtractOp consuming an InsertOp with the same
1365 /// indices, we can return the InsertOp's scalar directly.
1366 // TODO: This only checks the immediate producer; extend to go up the
1367 // insert/extract chain if the slices are disjoint.
1368 static Value foldExtractAfterInsert(ExtractOp extractOp) {
1369  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1370 
1371  auto isSame = [](Value a, Value b) {
1372  return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1373  };
1374  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1375  llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1376  return insertOp.getScalar();
1377 
1378  return {};
1379 }
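Illustrative fold (editor's addition): extracting at the same indices that were just inserted yields the inserted scalar directly:
```mlir
%1 = tensor.insert %scalar into %dest[%i, %j] : tensor<4x4xf32>
%2 = tensor.extract %1[%i, %j] : tensor<4x4xf32>   // folds to %scalar
```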
1380 
1381 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1382  if (Attribute tensor = adaptor.getTensor()) {
1383  // If this is a splat elements attribute, simply return the value.
1384  // All of the elements of a splat attribute are the same.
1385  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1386  return splatTensor.getSplatValue<Attribute>();
1387 
1388  // If this is a dense resource elements attribute, return.
1389  if (isa<DenseResourceElementsAttr>(tensor))
1390  return {};
1391  }
1392 
1393  // Collect the constant indices into the tensor.
1394  SmallVector<uint64_t, 8> indices;
1395  for (Attribute indice : adaptor.getIndices()) {
1396  if (!indice || !llvm::isa<IntegerAttr>(indice))
1397  return {};
1398  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1399  }
1400 
1401  // Fold extract(from_elements(...)).
1402  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1403  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1404  auto rank = tensorType.getRank();
1405  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1406  "rank mismatch");
1407  int flatIndex = 0;
1408  int stride = 1;
1409  for (int i = rank - 1; i >= 0; --i) {
1410  flatIndex += indices[i] * stride;
1411  stride *= tensorType.getDimSize(i);
1412  }
1413  // Prevent out of bounds accesses. This can happen in invalid code that
1414  // will never execute.
1415  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1416  flatIndex < 0)
1417  return {};
1418  return fromElementsOp.getElements()[flatIndex];
1419  }
1420 
1421  // If this is an elements attribute, query the value at the given indices.
1422  if (Attribute tensor = adaptor.getTensor()) {
1423  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1424  if (elementsAttr && elementsAttr.isValidIndex(indices))
1425  return elementsAttr.getValues<Attribute>()[indices];
1426  }
1427 
1428  if (Value result = foldExtractAfterInsert(*this))
1429  return result;
1430 
1431  return {};
1432 }
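A worked example of the `from_elements` folding above (illustrative, assuming index constants `%c1` and `%c2`): indices are linearized in row-major order, so for a `tensor<2x3xf32>` the element at `[1, 2]` has flat index 1 * 3 + 2 = 5:
```mlir
%t = tensor.from_elements %a, %b, %c, %d, %e, %f : tensor<2x3xf32>
%v = tensor.extract %t[%c1, %c2] : tensor<2x3xf32>   // folds to %f
```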
1433 
1434 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1435  MLIRContext *context) {
1436  results.add<ExtractFromTensorCast>(context);
1437 }
1438 
1439 void mlir::tensor::populateFoldCollapseExtractPatterns(
1440  RewritePatternSet &patterns) {
1441  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1442 }
1443 
1444 //===----------------------------------------------------------------------===//
1445 // FromElementsOp
1446 //===----------------------------------------------------------------------===//
1447 
1448 void FromElementsOp::getAsmResultNames(
1449  function_ref<void(Value, StringRef)> setNameFn) {
1450  setNameFn(getResult(), "from_elements");
1451 }
1452 
1453 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1454  ValueRange elements) {
1455  assert(!elements.empty() && "expected at least one element");
1456  Type resultType = RankedTensorType::get(
1457  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1458  build(builder, result, resultType, elements);
1459 }
1460 
1461 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1462  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1463  return DenseElementsAttr::get(getType(), adaptor.getElements());
1464  return {};
1465 }
1466 
1467 namespace {
1468 
1469 // Pushes the index_casts that occur before extractions to after the extract.
1470 // This minimizes type conversion in some cases and enables the extract
1471 // canonicalizer. This changes:
1472 //
1473 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1474 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1475 //
1476 // to the following:
1477 //
1478 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1479 // %cast = arith.index_cast %extract : i32 to index
1480 //
1481 // to just %element.
1482 //
1483 // Consider expanding this to a template and handle all tensor cast
1484 // operations.
1485 struct ExtractElementFromIndexCast
1486  : public OpRewritePattern<tensor::ExtractOp> {
1487  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1488 
1489  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1490  PatternRewriter &rewriter) const final {
1491  Location loc = extract.getLoc();
1492  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1493  if (!indexCast)
1494  return failure();
1495 
1496  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1497 
1498  auto newExtract = tensor::ExtractOp::create(
1499  rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1500 
1501  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1502  newExtract);
1503 
1504  return success();
1505  }
1506 };
1507 
1508 } // namespace
1509 
1510 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1511  MLIRContext *context) {
1512  results.add<ExtractElementFromIndexCast>(context);
1513 }
1514 
1515 //===----------------------------------------------------------------------===//
1516 // GatherOp
1517 //===----------------------------------------------------------------------===//
1518 
1519 void GatherOp::getAsmResultNames(
1520  function_ref<void(Value, StringRef)> setNameFn) {
1521  setNameFn(getResult(), "gather");
1522 }
1523 
1524 /// Return the inferred result type for a gatherOp where:
1525 /// - sourceType is the type of the source tensor gathered from
1526 /// - indicesType is the type of the indices used to gather
1527 /// - gatherDims are the dims along which the gather occurs.
1528 /// Return a full rank or ranked-reduced variant of the type depending on
1529 /// the value of rankReduced.
1530 ///
1531 /// The leading dimensions of the index tensor give the result tensor its
1532 /// leading dimensions.
1533 /// The trailing dimensions of the result tensor are obtained from the source
1534 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1535 /// rankReduced is false), or skipping them (otherwise).
1536 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1537  RankedTensorType indicesType,
1538  ArrayRef<int64_t> gatherDims,
1539  bool rankReduced) {
1540  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1541  resultShape.reserve(resultShape.size() + sourceType.getRank());
1542  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1543  if (llvm::binary_search(gatherDims, idx)) {
1544  if (!rankReduced)
1545  resultShape.push_back(1);
1546  continue;
1547  }
1548  resultShape.push_back(sourceType.getDimSize(idx));
1549  }
1550  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1551 }
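An inference example (editor's addition): with `gather_dims([0, 2])`, gathering from `tensor<4x5x6xf32>` using indices of type `tensor<7x2xindex>` yields `tensor<7x1x5x1xf32>`, or `tensor<7x5xf32>` in the rank-reduced variant:
```mlir
%out = tensor.gather %source[%indices] gather_dims([0, 2])
    : (tensor<4x5x6xf32>, tensor<7x2xindex>) -> tensor<7x1x5x1xf32>
```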
1552 
1553 static LogicalResult
1554 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1555  ArrayRef<int64_t> indices, int64_t rank,
1556  StringRef gatherOrScatter, StringRef sourceOrDest) {
1557  if (dims.empty())
1558  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1559 
1560  int64_t numGatherDims = dims.size();
1561  if (numGatherDims > rank)
1562  return op->emitOpError(gatherOrScatter)
1563  << "_dims overflow " << sourceOrDest << " rank";
1564  if (indices.empty() || indices.back() != numGatherDims)
1565  return op->emitOpError(gatherOrScatter)
1566  << "_dims length must match the size of last dimension of indices";
1567  for (int64_t val : dims) {
1568  if (val < 0)
1569  return op->emitOpError(gatherOrScatter)
1570  << "_dims value must be non-negative";
1571  if (val >= rank)
1572  return op->emitOpError(gatherOrScatter)
1573  << "_dims value must be smaller than " << sourceOrDest << " rank";
1574  }
1575  for (int64_t i = 1; i < numGatherDims; ++i) {
1576  if (dims[i - 1] >= dims[i])
1577  return op->emitOpError(gatherOrScatter)
1578  << "_dims values must be strictly increasing";
1579  }
1580  return success();
1581 }
1582 
1583 LogicalResult GatherOp::verify() {
1584  int64_t sourceRank = getSourceType().getRank();
1585  ArrayRef<int64_t> gatherDims = getGatherDims();
1586  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1587  getIndicesType().getShape(), sourceRank,
1588  "gather", "source")))
1589  return failure();
1590 
1591  RankedTensorType expectedResultType = GatherOp::inferResultType(
1592  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1593  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1594  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1595  if (getResultType() != expectedResultType &&
1596  getResultType() != expectedRankReducedResultType) {
1597  return emitOpError("result type "
1598  "mismatch: "
1599  "expected ")
1600  << expectedResultType << " or its rank-reduced variant "
1601  << expectedRankReducedResultType << " (got: " << getResultType()
1602  << ")";
1603  }
1604 
1605  return success();
1606 }
1607 
1608 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1609  if (OpFoldResult reshapedSource = reshapeConstantSource(
1610  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1611  getResult().getType()))
1612  return reshapedSource;
1613  return {};
1614 }
1615 
1616 //===----------------------------------------------------------------------===//
1617 // InsertOp
1618 //===----------------------------------------------------------------------===//
1619 
1620 void InsertOp::getAsmResultNames(
1621  function_ref<void(Value, StringRef)> setNameFn) {
1622  setNameFn(getResult(), "inserted");
1623 }
1624 
1625 LogicalResult InsertOp::verify() {
1626  // Verify the # indices match if we have a ranked type.
1627  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1628  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1629  return emitOpError("incorrect number of indices");
1630  return success();
1631 }
1632 
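/// Note: inserting a scalar into a destination that is a constant splat of
/// that same scalar is a no-op; the fold below returns the destination
/// unchanged in that case.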
1633 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1634  Attribute scalar = adaptor.getScalar();
1635  Attribute dest = adaptor.getDest();
1636  if (scalar && dest)
1637  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1638  if (scalar == splatDest.getSplatValue<Attribute>())
1639  return dest;
1640  return {};
1641 }
1642 
1643 //===----------------------------------------------------------------------===//
1644 // GenerateOp
1645 //===----------------------------------------------------------------------===//
1646 
1647 void GenerateOp::getAsmResultNames(
1648  function_ref<void(Value, StringRef)> setNameFn) {
1649  setNameFn(getResult(), "generated");
1650 }
1651 
1652 LogicalResult GenerateOp::reifyResultShapes(
1653  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1654  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1655  int idx = 0;
1656  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1657  if (getType().isDynamicDim(dim)) {
1658  reifiedReturnShapes[0][dim] = getOperand(idx++);
1659  } else {
1660  reifiedReturnShapes[0][dim] =
1661  builder.getIndexAttr(getType().getDimSize(dim));
1662  }
1663  }
1664  return success();
1665 }
1666 
1667 LogicalResult GenerateOp::verify() {
1668  // Ensure that the tensor type has as many dynamic dimensions as are
1669  // specified by the operands.
1670  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1671  if (getNumOperands() != resultType.getNumDynamicDims())
1672  return emitError("must have as many index operands as dynamic extents "
1673  "in the result type");
1674  return success();
1675 }
1676 
1677 LogicalResult GenerateOp::verifyRegions() {
1678  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1679  // Ensure that region arguments span the index space.
1680  if (!llvm::all_of(getBody().getArgumentTypes(),
1681  [](Type ty) { return ty.isIndex(); }))
1682  return emitError("all body arguments must be index");
1683  if (getBody().getNumArguments() != resultTy.getRank())
1684  return emitError("must have one body argument per input dimension");
1685 
1686  // Ensure that the region yields an element of the right type.
1687  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1688 
1689  if (yieldOp.getValue().getType() != resultTy.getElementType())
1690  return emitOpError(
1691  "body must be terminated with a `yield` operation of the tensor "
1692  "element type");
1693 
1694  return success();
1695 }
1696 
1697 void GenerateOp::build(
1698  OpBuilder &b, OperationState &result, Type resultTy,
1699  ValueRange dynamicExtents,
1700  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1701  build(b, result, resultTy, dynamicExtents);
1702 
1703  // Build and populate body.
1704  OpBuilder::InsertionGuard guard(b);
1705  Region *bodyRegion = result.regions.front().get();
1706  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1707  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1708  SmallVector<Location, 2> argumentLocs(rank, result.location);
1709  Block *bodyBlock =
1710  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1711  bodyBuilder(b, result.location, bodyBlock->getArguments());
1712 }
1713 
1714 namespace {
1715 
1716 /// Canonicalizes tensor.generate operations whose dynamic extent operands
1717 /// are constants into the equivalent operation with those extents folded
1718 /// into a static result type. A tensor.cast back to the original type is
1719 /// inserted so that the resulting IR remains well-typed.
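///
/// Illustrative example (%c8 assumed to be `arith.constant 8 : index`):
///
///   %gen = tensor.generate %c8 {
///   ^bb0(%i: index):
///     ...
///   } : tensor<?xf32>
///
/// becomes
///
///   %static = tensor.generate {
///   ^bb0(%i: index):
///     ...
///   } : tensor<8xf32>
///   %gen = tensor.cast %static : tensor<8xf32> to tensor<?xf32>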
1720 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1721  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1722 
1723  LogicalResult matchAndRewrite(GenerateOp generateOp,
1724  PatternRewriter &rewriter) const final {
1725  SmallVector<Value> foldedDynamicSizes;
1726  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1727  generateOp.getType(), generateOp.getDynamicExtents(),
1728  foldedDynamicSizes);
1729 
1730  // Stop here if no dynamic size was promoted to static.
1731  if (foldedTensorType == generateOp.getType())
1732  return failure();
1733 
1734  auto loc = generateOp.getLoc();
1735  auto newOp =
1736  GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1737  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1738  newOp.getBody().begin());
1739  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1740  generateOp.getType(), newOp);
1741  return success();
1742  }
1743 };
1744 
1745 /// Canonicalizes the pattern of the form
1746 ///
1747 /// %tensor = tensor.generate %x {
1748 /// ^bb0(%arg0: index):
1749 /// <computation>
1750 /// yield %1 : index
1751 /// } : tensor<?xindex>
1752 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1753 ///
1754 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1755 /// tensor.generate operation has no side-effects.
1756 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1757  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1758 
1759  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1760  PatternRewriter &rewriter) const final {
1761  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1762  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1763  return failure();
1764 
1765  IRMapping mapping;
1766  Block *body = &tensorFromElements.getBody().front();
1767  mapping.map(body->getArguments(), extract.getIndices());
1768  for (auto &op : body->without_terminator())
1769  rewriter.clone(op, mapping);
1770 
1771  auto yield = cast<YieldOp>(body->getTerminator());
1772 
1773  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1774  return success();
1775  }
1776 };
1777 
1778 } // namespace
1779 
1780 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1781  MLIRContext *context) {
1782  // TODO: Move extract pattern to tensor::ExtractOp.
1783  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1784 }
1785 
1786 //===----------------------------------------------------------------------===//
1787 // RankOp
1788 //===----------------------------------------------------------------------===//
1789 
1790 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1791  setNameFn(getResult(), "rank");
1792 }
1793 
1794 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1795  // Constant fold rank when the rank of the operand is known.
1796  auto type = getOperand().getType();
1797  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1798  if (shapedType && shapedType.hasRank())
1799  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1800  return IntegerAttr();
1801 }
1802 
1803 //===----------------------------------------------------------------------===//
1804 // ReshapeOp
1805 //===----------------------------------------------------------------------===//
1806 
1807 void ReshapeOp::getAsmResultNames(
1808  function_ref<void(Value, StringRef)> setNameFn) {
1809  setNameFn(getResult(), "reshape");
1810 }
1811 
1812 static int64_t getNumElements(ShapedType type) {
1813  int64_t numElements = 1;
1814  for (auto dim : type.getShape())
1815  numElements *= dim;
1816  return numElements;
1817 }
1818 
1819 LogicalResult ReshapeOp::verify() {
1820  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1821  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1822 
1823  if (operandType.getElementType() != resultType.getElementType())
1824  return emitOpError("element types of source and destination tensor "
1825  "types should be the same");
1826 
1827  int64_t shapeSize =
1828  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1829  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1830  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1831 
1832  if (resultRankedType) {
1833  if (operandRankedType && resultRankedType.hasStaticShape() &&
1834  operandRankedType.hasStaticShape()) {
1835  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1836  return emitOpError("source and destination tensor should have the "
1837  "same number of elements");
1838  }
1839  if (ShapedType::isDynamic(shapeSize))
1840  return emitOpError("cannot use shape operand with dynamic length to "
1841  "reshape to statically-ranked tensor type");
1842  if (shapeSize != resultRankedType.getRank())
1843  return emitOpError(
1844  "length of shape operand differs from the result's tensor rank");
1845  }
1846  return success();
1847 }
1848 
1849 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1850  if (OpFoldResult reshapedSource = reshapeConstantSource(
1851  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1852  getResult().getType()))
1853  return reshapedSource;
1854 
1855  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1856  // producer's input instead as the original tensor to reshape. This could
1857  // render such producer dead code.
1858  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1859  getSourceMutable().assign(reshapeOpProducer.getSource());
1860  return getResult();
1861  }
1862 
1863  auto source = getSource();
1864  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1865  auto resultTy = dyn_cast<RankedTensorType>(getType());
1866  if (!sourceTy || !resultTy || sourceTy != resultTy)
1867  return {};
1868 
1869  // If the source and result are both 0D or 1D tensors and have the same type,
1870  // the reshape has no effect, even if the tensor is dynamically shaped.
1871  if (sourceTy.getRank() <= 1)
1872  return source;
1873 
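  // Illustrative case handled below (%src, %c0, %c4 are assumed SSA values):
  // a reshape whose shape operand is rebuilt from the source's own dimensions
  // is a no-op:
  //   %d0 = tensor.dim %src, %c0 : tensor<?x4xf32>
  //   %shape = tensor.from_elements %d0, %c4 : tensor<2xindex>
  //   %r = tensor.reshape %src(%shape)
  //        : (tensor<?x4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
  // folds to %src.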
1874  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1875  auto elements = fromElements.getElements();
1876  bool dynamicNoop =
1877  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1878  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1879  auto element = elements[id];
1880 
1881  if (auto cst = getConstantIntValue(element)) {
1882  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1883  continue;
1884  }
1885 
1886  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1887  dynamicNoop &= dimOp.getSource() == source;
1888 
1889  auto cst = getConstantIntValue(dimOp.getIndex());
1890  dynamicNoop &=
1891  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1892  continue;
1893  }
1894 
1895  dynamicNoop = false;
1896  break;
1897  }
1898 
1899  if (dynamicNoop)
1900  return source;
1901  }
1902 
1903  return {};
1904 }
1905 
1906 //===----------------------------------------------------------------------===//
1907 // Reassociative reshape ops
1908 //===----------------------------------------------------------------------===//
1909 
1910 void CollapseShapeOp::getAsmResultNames(
1911  function_ref<void(Value, StringRef)> setNameFn) {
1912  setNameFn(getResult(), "collapsed");
1913 }
1914 
1915 void ExpandShapeOp::getAsmResultNames(
1916  function_ref<void(Value, StringRef)> setNameFn) {
1917  setNameFn(getResult(), "expanded");
1918 }
1919 
1920 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1921  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1922  "invalid resultDim");
1923  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1924  if (llvm::is_contained(it.value(), resultDim))
1925  return it.index();
1926  llvm_unreachable("could not find reassociation group");
1927 }
1928 
1929 FailureOr<SmallVector<OpFoldResult>>
1930 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1931  RankedTensorType expandedType,
1932  ArrayRef<ReassociationIndices> reassociation,
1933  ArrayRef<OpFoldResult> inputShape) {
1934  std::optional<SmallVector<OpFoldResult>> outputShape =
1935  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1936  inputShape);
1937  if (!outputShape)
1938  return failure();
1939  return *outputShape;
1940 }
1941 
1942 SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1943  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1944 }
1945 
1946 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1947  Type resultType, Value src,
1948  ArrayRef<ReassociationIndices> reassociation,
1949  ArrayRef<OpFoldResult> outputShape) {
1950  auto [staticOutputShape, dynamicOutputShape] =
1952  build(builder, result, cast<RankedTensorType>(resultType), src,
1953  getReassociationIndicesAttribute(builder, reassociation),
1954  dynamicOutputShape, staticOutputShape);
1955 }
1956 
1957 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1958  Type resultType, Value src,
1959  ArrayRef<ReassociationIndices> reassociation) {
1960  SmallVector<OpFoldResult> inputShape =
1961  getMixedSizes(builder, result.location, src);
1962  auto tensorResultTy = cast<RankedTensorType>(resultType);
1963  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1964  builder, result.location, tensorResultTy, reassociation, inputShape);
1965  SmallVector<OpFoldResult> outputShapeOrEmpty;
1966  if (succeeded(outputShape)) {
1967  outputShapeOrEmpty = *outputShape;
1968  }
1969  build(builder, result, tensorResultTy, src, reassociation,
1970  outputShapeOrEmpty);
1971 }
1972 
1973 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1974  return getSymbolLessAffineMaps(getReassociationExprs());
1975 }
1976 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1977  return convertReassociationIndicesToExprs(getContext(),
1978  getReassociationIndices());
1979 }
1980 
1981 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1982  return getSymbolLessAffineMaps(getReassociationExprs());
1983 }
1984 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1985  return convertReassociationIndicesToExprs(getContext(),
1986  getReassociationIndices());
1987 }
1988 
1989 RankedTensorType CollapseShapeOp::inferCollapsedType(
1990  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1991  return inferCollapsedType(
1992  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1993  type.getContext(), reassociation)));
1994 }
1995 
1996 /// Compute the RankedTensorType obtained by applying `reassociation` to
1997 /// `type`.
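/// For example (illustrative), collapsing tensor<2x3x?x5xf32> with
/// reassociation [[0, 1], [2, 3]] yields tensor<6x?xf32>: static groups are
/// multiplied out, and any group containing a dynamic size collapses to a
/// dynamic size.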
1998 RankedTensorType
1999 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2000  ArrayRef<AffineMap> reassociation) {
2001  auto shape = type.getShape();
2002  SmallVector<int64_t, 4> newShape;
2003  newShape.reserve(reassociation.size());
2004 
2005  // Use the fact that reassociation is valid to simplify the logic: only use
2006  // each map's rank.
2007  assert(isReassociationValid(reassociation) && "invalid reassociation");
2008  unsigned currentDim = 0;
2009  for (AffineMap m : reassociation) {
2010  unsigned dim = m.getNumResults();
2011  auto band = shape.slice(currentDim, dim);
2012  int64_t size = 1;
2013  if (llvm::is_contained(band, ShapedType::kDynamic))
2014  size = ShapedType::kDynamic;
2015  else
2016  for (unsigned d = 0; d < dim; ++d)
2017  size *= shape[currentDim + d];
2018  newShape.push_back(size);
2019  currentDim += dim;
2020  }
2021 
2022  return RankedTensorType::get(newShape, type.getElementType());
2023 }
2024 
2025 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2026  ArrayRef<ReassociationIndices> reassociation,
2027  ArrayRef<NamedAttribute> attrs) {
2028  auto resultType = inferCollapsedType(
2029  llvm::cast<RankedTensorType>(src.getType()),
2030  getSymbolLessAffineMaps(
2031  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2032  result.addAttribute(getReassociationAttrStrName(),
2033  getReassociationIndicesAttribute(b, reassociation));
2034  build(b, result, resultType, src, attrs);
2035 }
2036 
2037 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2038  TensorReshapeOp, ExpandShapeOp>::value>
2039 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2040  RankedTensorType expandedType,
2041  RankedTensorType collapsedType) {
2042  if (failed(
2043  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2044  return failure();
2045 
2046  auto maps = op.getReassociationMaps();
2047  RankedTensorType expectedType =
2048  CollapseShapeOp::inferCollapsedType(expandedType, maps);
2049  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2050  return op.emitOpError("expected collapsed type to be ")
2051  << expectedType << ", but got " << collapsedType;
2052  return success();
2053 }
2054 
2055 LogicalResult ExpandShapeOp::verify() {
2056  auto srcType = getSrcType();
2057  auto resultType = getResultType();
2058 
2059  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2060  return emitOpError("expected number of static shape dims to be equal to "
2061  "the output rank (")
2062  << resultType.getRank() << ") but found "
2063  << getStaticOutputShape().size() << " inputs instead";
2064 
2065  if ((int64_t)getOutputShape().size() !=
2066  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2067  return emitOpError("mismatch in dynamic dims in output_shape and "
2068  "static_output_shape: static_output_shape has ")
2069  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2070  << " dynamic dims while output_shape has " << getOutputShape().size()
2071  << " values";
2072 
2073  return verifyTensorReshapeOp(*this, resultType, srcType);
2074 }
2075 
2076 LogicalResult CollapseShapeOp::verify() {
2077  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2078 }
2079 
2080 namespace {
2081 /// Reshape of a splat constant can be replaced with a constant of the result
2082 /// type.
2083 template <typename TensorReshapeOp>
2084 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2085  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2086  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2087  PatternRewriter &rewriter) const override {
2088  DenseElementsAttr attr;
2089  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2090  return failure();
2091  if (!attr || !attr.isSplat())
2092  return failure();
2093  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2094  reshapeOp.getResultType(), attr.getRawData());
2095  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2096  return success();
2097  }
2098 };
2099 
2100 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2101 template <typename TensorReshapeOp>
2102 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2103 public:
2104  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2105 
2106  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2107  PatternRewriter &rewriter) const override {
2108  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2109  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2110  return failure();
2111 
2112  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2113  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2114  return success();
2115  }
2116 };
2117 
2118 /// Reshape of a FromElements can be replaced with a FromElements of the
2119 /// result type.
2120 template <typename TensorReshapeOp>
2121 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2122  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2123  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2124  PatternRewriter &rewriter) const override {
2125  auto fromElements =
2126  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2127  if (!fromElements)
2128  return failure();
2129 
2130  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2131 
2132  if (!shapedTy.hasStaticShape())
2133  return failure();
2134 
2135  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2136  fromElements.getElements());
2137  return success();
2138  }
2139 };
2140 
2141 // Fold CastOp into CollapseShapeOp when adding static information.
2142 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2143  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2144 
2145  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2146  PatternRewriter &rewriter) const override {
2147  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2148  if (!tensor::canFoldIntoConsumerOp(castOp))
2149  return failure();
2150 
2151  RankedTensorType srcType =
2152  llvm::cast<RankedTensorType>(castOp.getSource().getType());
2153  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2154  srcType, collapseShapeOp.getReassociationMaps());
2155 
2156  if (newResultType == collapseShapeOp.getResultType()) {
2157  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2158  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2159  });
2160  } else {
2161  auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2162  newResultType, castOp.getSource(),
2163  collapseShapeOp.getReassociation());
2164  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2165  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2166  }
2167  return success();
2168  }
2169 };
2170 
2171 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2172 /// matching constant output_shape operands of the expand. This makes the
2173 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2174 /// propagated further.
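///
/// Sketch of the rewrite (illustrative; %t, %c2, %c5 are assumed SSA values):
///
///   %c = tensor.cast %t : tensor<10xf32> to tensor<?xf32>
///   %e = tensor.expand_shape %c [[0, 1]] output_shape [%c2, %c5]
///        : tensor<?xf32> into tensor<?x?xf32>
///
/// becomes a static expand of %t to tensor<2x5xf32>, followed by a
/// tensor.cast back to tensor<?x?xf32>.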
2175 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2176  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2177 
2178  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2179  PatternRewriter &rewriter) const override {
2180  auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2181  if (!canFoldIntoConsumerOp(castOp))
2182  return failure();
2183 
2184  ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2185  SmallVector<ReassociationIndices, 4> reassoc =
2186  expandOp.getReassociationIndices();
2187 
2188  SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2189  SmallVector<Value> dynamicOutputShape;
2190  auto outputIt = expandOp.getOutputShape().begin();
2191 
2192  for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2193  for (uint64_t outDim : innerReassoc) {
2194  if (ShapedType::isStatic(newOutputShape[outDim]))
2195  continue;
2196 
2197  // If the cast's src type is dynamic, don't infer any of the
2198  // corresponding expanded dimensions. `tensor.expand_shape` requires at
2199  // least one of the expanded dimensions to be dynamic if the input is
2200  // dynamic.
2201  Value val = *outputIt;
2202  ++outputIt;
2203  if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2204  dynamicOutputShape.push_back(val);
2205  continue;
2206  }
2207 
2208  APInt cst;
2209  if (matchPattern(val, m_ConstantInt(&cst))) {
2210  newOutputShape[outDim] = cst.getSExtValue();
2211  } else {
2212  dynamicOutputShape.push_back(val);
2213  }
2214  }
2215  }
2216 
2217  // Couldn't match any values, nothing to change
2218  if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2219  return failure();
2220 
2221  // Calculate the input shape from the output
2222  SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2223  for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2224  for (auto outDim : reassoc[inDim]) {
2225  auto ofr = newOutputShape[outDim];
2226  if (ShapedType::isDynamic(ofr)) {
2227  newInputShape[inDim] = ShapedType::kDynamic;
2228  break;
2229  }
2230  newInputShape[inDim] *= ofr;
2231  }
2232  }
2233 
2234  SmallVector<OpFoldResult> outputOfr =
2235  getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2236  auto inputType = RankedTensorType::get(
2237  newInputShape, expandOp.getSrcType().getElementType());
2238  auto outputType = RankedTensorType::get(
2239  newOutputShape, expandOp.getSrcType().getElementType());
2240  auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2241  expandOp.getSrc());
2242  auto newExpand = ExpandShapeOp::create(
2243  rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2244  expandOp.getReassociationIndices(), outputOfr);
2245  rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2246  newExpand.getResult());
2247  return success();
2248  }
2249 };
2250 } // namespace
2251 
2252 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2253  MLIRContext *context) {
2254  results.add<
2255  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2256  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2257  ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2258  FoldReshapeWithSplat<ExpandShapeOp>,
2259  FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2260 }
2261 
2262 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2263  MLIRContext *context) {
2264  results.add<
2265  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2266  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2267  tensor::DimOp, RankedTensorType>,
2268  FoldReshapeWithConstant<CollapseShapeOp>,
2269  FoldReshapeWithSplat<CollapseShapeOp>,
2270  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2271  context);
2272 }
2273 
2274 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2275  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2276  adaptor.getOperands());
2277 }
2278 
2279 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2280  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2281  adaptor.getOperands());
2282 }
2283 
2284 //===----------------------------------------------------------------------===//
2285 // ExtractSliceOp
2286 //===----------------------------------------------------------------------===//
2287 
2288 void ExtractSliceOp::getAsmResultNames(
2289  function_ref<void(Value, StringRef)> setNameFn) {
2290  setNameFn(getResult(), "extracted_slice");
2291 }
2292 
2293 /// An extract_slice result type can be inferred, when it is not
2294 /// rank-reduced, from the source type and the static representation of
2295 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2296 RankedTensorType ExtractSliceOp::inferResultType(
2297  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2298  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2299  // An extract_slice op may specify only a leading subset of offsets/sizes/
2300  // strides, in which case we complete with offset=0, sizes from the source
2301  // tensor type, and strides=1.
2302  assert(static_cast<int64_t>(staticSizes.size()) ==
2303  sourceTensorType.getRank() &&
2304  "unexpected staticSizes not equal to rank of source");
2305  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2306  sourceTensorType.getEncoding());
2307 }
2308 
2309 // TODO: This uses neither offsets nor strides!
2310 RankedTensorType ExtractSliceOp::inferResultType(
2311  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2312  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2313  SmallVector<int64_t> staticSizes;
2314  std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
2315  assert(static_cast<int64_t>(staticSizes.size()) ==
2316  sourceTensorType.getRank() &&
2317  "unexpected staticSizes not equal to rank of source");
2318  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2319  sourceTensorType.getEncoding());
2320 }
2321 
2322 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2323 /// number of sizes), drop as many size 1 as needed to produce an inferred
2324 /// type with the desired rank.
2325 ///
2326 /// Note that there may be multiple ways to compute this rank-reduced type:
2327 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2328 ///
2329 /// To disambiguate, this function always drops the leading occurrences of size 1.
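/// Under this rule (illustrative), reducing sizes [1, 6, 1] to rank 2 drops
/// the leading unit dimension and produces a 6x1 result type.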
2330 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2331  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2332  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2333  ArrayRef<int64_t> strides) {
2334  // Type inferred in the absence of rank-reducing behavior.
2335  auto inferredType = llvm::cast<RankedTensorType>(
2336  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2337  int rankDiff = inferredType.getRank() - desiredResultRank;
2338  if (rankDiff > 0) {
2339  auto shape = inferredType.getShape();
2340  llvm::SmallBitVector dimsToProject =
2341  getPositionsOfShapeOne(rankDiff, shape);
2342  SmallVector<int64_t> projectedShape;
2343  // Best effort rank-reducing: drop 1s in order.
2344  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2345  if (!dimsToProject.test(pos))
2346  projectedShape.push_back(shape[pos]);
2347  inferredType =
2348  RankedTensorType::get(projectedShape, inferredType.getElementType());
2349  }
2350  return inferredType;
2351 }
2352 
2353 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2354  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2355  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2356  ArrayRef<OpFoldResult> strides) {
2357  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2358  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2359  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2360  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2361  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2362  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2363  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2364  staticStrides);
2365 }
2366 
2367 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2368 /// result type. If the type passed is nullptr, it is inferred.
2369 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2370  RankedTensorType resultType, Value source,
2371  ArrayRef<OpFoldResult> offsets,
2372  ArrayRef<OpFoldResult> sizes,
2373  ArrayRef<OpFoldResult> strides,
2374  ArrayRef<NamedAttribute> attrs) {
2375  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2376  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2377  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2378  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2379  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2380  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2381  // Structuring implementation this way avoids duplication between builders.
2382  if (!resultType) {
2383  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2384  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2385  }
2386  result.addAttributes(attrs);
2387  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2388  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2389  b.getDenseI64ArrayAttr(staticSizes),
2390  b.getDenseI64ArrayAttr(staticStrides));
2391 }
2392 
2393 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2394 /// result type.
2395 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2396  ArrayRef<OpFoldResult> offsets,
2397  ArrayRef<OpFoldResult> sizes,
2398  ArrayRef<OpFoldResult> strides,
2399  ArrayRef<NamedAttribute> attrs) {
2400  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2401 }
2402 
2403 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2404 /// a Range vector.
2405 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2406  ArrayRef<Range> ranges,
2407  ArrayRef<NamedAttribute> attrs) {
2408  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2409  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2410 }
2411 
2412 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2413 /// the type passed is nullptr, it is inferred.
2414 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2415  RankedTensorType resultType, Value source,
2416  ValueRange offsets, ValueRange sizes,
2417  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2418  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2419  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2420  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2421  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2422  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2423  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2424  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2425 }
2426 
2427 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2428 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2429  ValueRange offsets, ValueRange sizes,
2430  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2431  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2432 }
2433 
2434 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2435  Operation *op,
2436  RankedTensorType expectedType) {
2437  switch (result) {
2438  case SliceVerificationResult::Success:
2439  return success();
2440  case SliceVerificationResult::RankTooLarge:
2441  return op->emitError("expected rank to be smaller or equal to ")
2442  << "the other rank. ";
2443  case SliceVerificationResult::SizeMismatch:
2444  return op->emitError("expected type to be ")
2445  << expectedType << " or a rank-reduced version. (size mismatch) ";
2446  case SliceVerificationResult::ElemTypeMismatch:
2447  return op->emitError("expected element type to be ")
2448  << expectedType.getElementType();
2449  default:
2450  llvm_unreachable("unexpected extract_slice op verification result");
2451  }
2452 }
2453 
2454 /// Verifier for ExtractSliceOp.
2455 LogicalResult ExtractSliceOp::verify() {
2456  RankedTensorType sourceType = getSourceType();
2457 
2458  // Verify result type against inferred type.
2459  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2460  sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2461  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2462  if (result != SliceVerificationResult::Success)
2463  return produceSliceErrorMsg(result, *this, expectedType);
2464 
2465  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2466  // to the source tensor.
2467  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2468  sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2469  getStaticStrides(), /*generateErrorMessage=*/true);
2470  if (!boundsResult.isValid)
2471  return getOperation()->emitError(boundsResult.errorMessage);
2472 
2473  return success();
2474 }
2475 
2476 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2478 }
2479 
2480 FailureOr<Value>
2481 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2482  ArrayRef<int64_t> desiredShape) {
2483  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2484  assert(sourceTensorType && "not a ranked tensor type");
2485  auto sourceShape = sourceTensorType.getShape();
2486  if (sourceShape.equals(desiredShape))
2487  return value;
2488  auto maybeRankReductionMask =
2489  mlir::computeRankReductionMask(sourceShape, desiredShape);
2490  if (!maybeRankReductionMask)
2491  return failure();
2492  return createCanonicalRankReducingExtractSliceOp(
2493  b, loc, value,
2494  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2495 }
2496 
2497 LogicalResult ExtractSliceOp::reifyResultShapes(
2498  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2499  reifiedReturnShapes.resize(1);
2500  reifiedReturnShapes[0].reserve(getType().getRank());
2501  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2502  llvm::SmallBitVector droppedDims = getDroppedDims();
2503  for (const auto &size : enumerate(mixedSizes)) {
2504  if (droppedDims.test(size.index()))
2505  continue;
2506  reifiedReturnShapes[0].push_back(size.value());
2507  }
2508  return success();
2509 }
2510 
2511 namespace {
2512 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2513 /// This essentially pushes memref_cast past its consuming slice when
2514 /// `canFoldIntoConsumerOp` is true.
2515 ///
2516 /// Example:
2517 /// ```
2518 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2519 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2520 /// tensor<3x4xf32>
2521 /// ```
2522 /// is rewritten into:
2523 /// ```
2524 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2525 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2526 /// ```
2527 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2528 public:
2529  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2530 
2531  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2532  PatternRewriter &rewriter) const override {
2533  // Any constant operand, just return to let the constant folder kick in.
2534  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2535  return matchPattern(operand, matchConstantIndex());
2536  }))
2537  return failure();
2538 
2539  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2540  if (!castOp)
2541  return failure();
2542 
2543  if (!canFoldIntoConsumerOp(castOp))
2544  return failure();
2545 
2546  // Pattern does not apply if the produced op would not verify.
2547  SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2548  cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2549  sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2550  sliceOp.getStaticStrides());
2551  if (!sliceResult.isValid)
2552  return failure();
2553 
2554  // Create folded extract.
2555  Location loc = sliceOp.getLoc();
2556  Value newResult = ExtractSliceOp::create(
2557  rewriter, loc, sliceOp.getType(), castOp.getSource(),
2558  sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2559  sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2560  sliceOp.getStaticStrides());
2561  rewriter.replaceOp(sliceOp, newResult);
2562  return success();
2563  }
2564 };
2565 
2566 /// Slice elements from `values` into `outValues`. For each dimension,
2567 /// `counts` holds the number of linearized elements spanned by one step
2568 /// along that dimension. The output values can be used to construct a DenseElementsAttr.
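///
/// Illustrative walk-through (values chosen for exposition): for a 3x4
/// source, `counts` is [4, 1]; with offsets [1, 0], sizes [2, 2], and strides
/// [1, 2], the recursion copies the elements at (1,0), (1,2), (2,0), and
/// (2,2), in that order.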
2569 template <typename IterTy, typename ElemTy>
2570 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2571  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2572  ArrayRef<int64_t> strides,
2573  llvm::SmallVectorImpl<ElemTy> *outValues) {
2574  assert(offsets.size() == sizes.size());
2575  assert(offsets.size() == strides.size());
2576  if (offsets.empty())
2577  return;
2578 
2579  int64_t offset = offsets.front();
2580  int64_t size = sizes.front();
2581  int64_t stride = strides.front();
2582  if (offsets.size() == 1) {
2583  for (int64_t i = 0; i < size; ++i, offset += stride)
2584  outValues->push_back(*(values + offset));
2585 
2586  return;
2587  }
2588 
2589  for (int64_t i = 0; i < size; ++i, offset += stride) {
2590  auto begin = values + offset * counts.front();
2591  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2592  offsets.drop_front(), sizes.drop_front(),
2593  strides.drop_front(), outValues);
2594  }
2595 }
2596 
2597 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2598 /// folded operation might introduce more constant data; users can guard
2599 /// against that with the control function.
2600 class ConstantOpExtractSliceFolder final
2601  : public OpRewritePattern<ExtractSliceOp> {
2602 public:
2603  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2604 
2605  ConstantOpExtractSliceFolder(MLIRContext *context,
2606  ControlConstantExtractSliceFusionFn controlFn)
2607  : OpRewritePattern<ExtractSliceOp>(context),
2608  controlFn(std::move(controlFn)) {}
2609 
2610  LogicalResult matchAndRewrite(ExtractSliceOp op,
2611  PatternRewriter &rewriter) const override {
2612  DenseElementsAttr attr;
2613  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2614  return failure();
2615 
2616  // A constant splat is handled by fold().
2617  if (attr.isSplat())
2618  return failure();
2619 
2620  // Dynamic result shape is not supported.
2621  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2622  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2623  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2624  return failure();
2625 
2626  // Customized control over the folding.
2627  if (!controlFn(op))
2628  return failure();
2629 
2630  int64_t count = sourceType.getNumElements();
2631  if (count == 0)
2632  return failure();
2633 
2634  // Check if there are any dynamic parts, which are not supported.
2635  auto offsets = op.getStaticOffsets();
2636  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2637  return failure();
2638  auto sizes = op.getStaticSizes();
2639  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2640  return failure();
2641  auto strides = op.getStaticStrides();
2642  if (llvm::is_contained(strides, ShapedType::kDynamic))
2643  return failure();
2644 
2645  // Compute the stride for each dimension.
2646  SmallVector<int64_t> counts;
2647  ArrayRef<int64_t> shape = sourceType.getShape();
2648  counts.reserve(shape.size());
2649  for (int64_t v : shape) {
2650  count = count / v;
2651  counts.push_back(count);
2652  }
2653 
2654  // New attribute constructed by the sliced values.
2655  DenseElementsAttr newAttr;
2656 
2657  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2658  SmallVector<APInt> outValues;
2659  outValues.reserve(sourceType.getNumElements());
2660  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2661  elems.begin(), counts, offsets, sizes, strides, &outValues);
2662  newAttr = DenseElementsAttr::get(resultType, outValues);
2663  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2664  SmallVector<APFloat> outValues;
2665  outValues.reserve(sourceType.getNumElements());
2666  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2667  elems.begin(), counts, offsets, sizes, strides, &outValues);
2668  newAttr = DenseElementsAttr::get(resultType, outValues);
2669  }
2670 
2671  if (newAttr) {
2672  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2673  return success();
2674  }
2675 
2676  return failure();
2677  }
2678 
2679 private:
2680  /// This additionally controls whether the fold happens or not. Users can
2681  /// impose their own heuristics in this function.
2682  ControlConstantExtractSliceFusionFn controlFn;
2683 };
2684 
2685 } // namespace
2686 
2687 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2688  RewritePatternSet &patterns,
2689  const ControlConstantExtractSliceFusionFn &controlFn) {
2690  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2691 }
2692 
2693 /// Return the canonical type of the result of an extract_slice op.
2694 struct SliceReturnTypeCanonicalizer {
2695  RankedTensorType operator()(ExtractSliceOp op,
2696  ArrayRef<OpFoldResult> mixedOffsets,
2697  ArrayRef<OpFoldResult> mixedSizes,
2698  ArrayRef<OpFoldResult> mixedStrides) {
2699  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2700  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2701  mixedStrides);
2702  }
2703 };
2704 
2705 /// A canonicalizer wrapper to replace ExtractSliceOps.
2706 struct SliceCanonicalizer {
2707  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2708  ExtractSliceOp newOp) {
2709  Value replacement = newOp.getResult();
2710  if (replacement.getType() != op.getType())
2711  replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2712  replacement);
2713  rewriter.replaceOp(op, replacement);
2714  }
2715 };
2716 
2717 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2718  MLIRContext *context) {
2719  results.add<
2720  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2721  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2722  ExtractSliceOpCastFolder>(context);
2723 }
2724 
2725 //
2726 static LogicalResult
2727 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2728  ShapedType shapedType) {
2729  OpBuilder b(op.getContext());
2730  for (OpFoldResult ofr : op.getMixedOffsets())
2731  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2732  return failure();
2733  // Rank-reducing noops only need to inspect the leading dimensions:
2734  // llvm::zip is appropriate.
2735  auto shape = shapedType.getShape();
2736  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2737  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2738  return failure();
2739  for (OpFoldResult ofr : op.getMixedStrides())
2740  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2741  return failure();
2742  return success();
2743 }
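// An identity slice recognized by the helper above (illustrative):
//   %0 = tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
//        : tensor<4x8xf32> to tensor<4x8xf32>
// Offsets are all 0, sizes match the source shape, and strides are all 1, so
// the op folds to %t.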
2744 
2745 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2746 /// slice, we can return the InsertSliceOp's source directly.
2747 // TODO: This only checks the immediate producer; extend to go up the
2748 // insert/extract chain if the slices are disjoint.
2749 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2750  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2751 
2752  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2753  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2754  insertOp.isSameAs(extractOp, isSame))
2755  return insertOp.getSource();
2756 
2757  return {};
2758 }
2759 
2760 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2761  if (OpFoldResult reshapedSource = reshapeConstantSource(
2762  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2763  getResult().getType()))
2764  return reshapedSource;
2765  if (getSourceType() == getType() &&
2766  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2767  return this->getSource();
2768  if (Value slice = foldExtractAfterInsertSlice(*this))
2769  return slice;
2770 
2771  return OpFoldResult();
2772 }
2773 
2775  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2776  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2777  unsigned rank = rankedTensorType.getRank();
2778  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2779  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2780  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2781  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2782  offsets, sizes, strides);
2783 }
2784 
2785 //===----------------------------------------------------------------------===//
2786 // InsertSliceOp
2787 //===----------------------------------------------------------------------===//
2788 
2789 void InsertSliceOp::getAsmResultNames(
2790  function_ref<void(Value, StringRef)> setNameFn) {
2791  setNameFn(getResult(), "inserted_slice");
2792 }
2793 
2794 // Build an InsertSliceOp with mixed static and dynamic entries.
2795 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2796  Value dest, ArrayRef<OpFoldResult> offsets,
2797  ArrayRef<OpFoldResult> sizes,
2798  ArrayRef<OpFoldResult> strides,
2799  ArrayRef<NamedAttribute> attrs) {
2800  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2801  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2802  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2803  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2804  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2805  result.addAttributes(attrs);
2806  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2807  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2808  b.getDenseI64ArrayAttr(staticSizes),
2809  b.getDenseI64ArrayAttr(staticStrides));
2810 }
2811 
2812 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2813 /// Range vector.
2814 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2815  Value dest, ArrayRef<Range> ranges,
2816  ArrayRef<NamedAttribute> attrs) {
2817  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2818  build(b, result, source, dest, offsets, sizes, strides, attrs);
2819 }
2820 
2821 // Build an InsertSliceOp with dynamic entries.
2822 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2823  Value dest, ValueRange offsets, ValueRange sizes,
2824  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2825  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2826  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2827  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2828  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2829  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2830  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2831  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2832 }
2833 
2834 /// Rank-reducing type verification for both InsertSliceOp and
2835 /// ParallelInsertSliceOp.
2836 static SliceVerificationResult verifyInsertSliceOp(
2837  RankedTensorType srcType, RankedTensorType dstType,
2838  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2839  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2840  // insert_slice is the inverse of extract_slice, use the same type
2841  // inference.
2842  RankedTensorType expected = ExtractSliceOp::inferResultType(
2843  dstType, staticOffsets, staticSizes, staticStrides);
2844  if (expectedType)
2845  *expectedType = expected;
2846  return isRankReducedType(expected, srcType);
2847 }
2848 
2849 /// Verifier for InsertSliceOp.
2850 LogicalResult InsertSliceOp::verify() {
2851  // Verify result type against inferred type.
2852  RankedTensorType expectedType;
2853  SliceVerificationResult result =
2854  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2855  getStaticSizes(), getStaticStrides(), &expectedType);
2856  if (result != SliceVerificationResult::Success)
2857  return produceSliceErrorMsg(result, *this, expectedType);
2858 
2859  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2860  // to the destination tensor.
2861  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2862  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2863  getStaticStrides(), /*generateErrorMessage=*/true);
2864  if (!boundsResult.isValid)
2865  return getOperation()->emitError(boundsResult.errorMessage);
2866 
2867  return success();
2868 }
2869 
2870 /// If we have two consecutive InsertSliceOps writing to the same slice, we
2871 /// can mutate the second InsertSliceOp's destination to the first one's.
2872 ///
2873 /// Example:
2874 ///
2875 /// ```mlir
2876 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2877 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2878 /// ```
2879 ///
2880 /// folds into:
2881 ///
2882 /// ```mlir
2883 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2884 /// ```
2885 ///
2886 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2887 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2888  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2889 
2890  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2891  if (!prevInsertOp ||
2892  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2893  !prevInsertOp.isSameAs(insertOp, isSame))
2894  return failure();
2895 
2896  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2897  return success();
2898 }
2899 
2900 /// Folds round-trip extract/insert slice op pairs.
2901 /// Example:
2902 /// ```mlir
2903 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2904 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2905 /// ```
2906 /// can be folded into %val.
2907 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2908  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2909 
2910  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2911  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2912  !extractOp.isSameAs(insertOp, isSame))
2913  return nullptr;
2914 
2915  return extractOp.getSource();
2916 }
2917 
2918 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2919  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2920  getSourceType() == getType() &&
2921  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2922  return this->getSource();
2923  if (succeeded(foldInsertAfterInsertSlice(*this)))
2924  return getResult();
2925  if (auto result = foldInsertAfterExtractSlice(*this))
2926  return result;
2927  if (llvm::any_of(getMixedSizes(), isZeroInteger))
2928  return getDest();
2929  return OpFoldResult();
2930 }
2931 
2932 LogicalResult InsertSliceOp::reifyResultShapes(
2933  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2934  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2935  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2936  return success();
2937 }
2938 
2939 namespace {
2940 /// Pattern to rewrite an insert_slice op with constant arguments.
2941 ///
2942 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2943 template <typename InsertOpTy>
2944 class InsertSliceOpConstantArgumentFolder final
2945  : public OpRewritePattern<InsertOpTy> {
2946 public:
2947  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2948 
2949  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2950  PatternRewriter &rewriter) const override {
2951  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2952  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2953  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2954 
2955  // No constant operands were folded, just return.
2956  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2957  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2958  failed(foldDynamicStrideList(mixedStrides)))
2959  return failure();
2960 
2961  // Pattern does not apply if the produced op would not verify.
2962  SliceBoundsVerificationResult sliceResult =
2963  verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2964  mixedOffsets, mixedSizes, mixedStrides);
2965  if (!sliceResult.isValid)
2966  return failure();
2967 
2968  // Create the new op in canonical form.
2969  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2970  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2971  mixedOffsets, mixedSizes, mixedStrides);
2972  Value toInsert = insertSliceOp.getSource();
2973  if (sourceType != insertSliceOp.getSourceType()) {
2974  OpBuilder::InsertionGuard g(rewriter);
2975  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2976  // is that the insertion point is just before the InParallelOp in
2977  // the parallel case.
2978  if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
2979  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2980  toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
2981  sourceType, toInsert);
2982  }
2983  rewriter.replaceOpWithNewOp<InsertOpTy>(
2984  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2985  mixedSizes, mixedStrides);
2986  return success();
2987  }
2988 };
2989 
2990 /// Fold tensor_casts with insert_slice operations. If the source or
2991 /// destination tensor is a tensor_cast that removes static type information,
2992 /// the cast is folded into the insert_slice operation. E.g.:
2993 ///
2994 /// ```mlir
2995 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2996 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2997 /// ```
2998 ///
2999 /// folds into:
3000 ///
3001 /// ```mlir
3002 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3003 /// ```
3004 ///
3005 /// Note: When folding a cast on the destination tensor, the result of the
3006 /// insert_slice operation is cast back so that the type of the result does
3007 /// not change.
3008 ///
3009 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
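/// For the destination case, an illustrative sketch (%src, %dest, and the
/// shapes are made up):
///
/// ```mlir
/// %d = tensor.cast %dest : tensor<8x16xf32> to tensor<?x?xf32>
/// %r = tensor.insert_slice %src into %d[0, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<?x?xf32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %i = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x16xf32>
/// %r = tensor.cast %i : tensor<8x16xf32> to tensor<?x?xf32>
/// ```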
3010 template <typename InsertOpTy>
3011 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3012  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3013 
3014  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3015  PatternRewriter &rewriter) const override {
3016  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3017  return matchPattern(operand, matchConstantIndex());
3018  }))
3019  return failure();
3020 
3021  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3022  auto castOp = v.getDefiningOp<tensor::CastOp>();
3023  if (!castOp || !canFoldIntoConsumerOp(castOp))
3024  return std::nullopt;
3025  return castOp.getSource();
3026  };
3027  std::optional<Value> sourceCastSource =
3028  getSourceOfCastOp(insertSliceOp.getSource());
3029  std::optional<Value> destCastSource =
3030  getSourceOfCastOp(insertSliceOp.getDest());
3031  if (!sourceCastSource && !destCastSource)
3032  return failure();
3033 
3034  auto src =
3035  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3036  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3037  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3038  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3039  if (!srcType || !dstType)
3040  return failure();
3041 
3042  // The tensor.cast source could have additional static information not seen
3043  // in the insert slice op static sizes, so we ignore dynamic dims when
3044  // computing the rank reduction mask.
3045  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3046  auto rankReductionMask = computeRankReductionMask(
3047  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3048  if (!rankReductionMask.has_value())
3049  return failure();
3050  // Replace dimensions in the insert slice op with corresponding static dims
3051  // from the cast source type. If the insert slice sizes have static dims
3052  // that are not static in the tensor.cast source (i.e., when the cast op
3053  // casts a dynamic dim to static), the dim should not be replaced, and the
3054  // pattern will fail later in `verifyInsertSliceOp`.
3055  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3056  int64_t rankReducedIdx = 0;
3057  for (auto [idx, size] : enumerate(staticSizes)) {
3058  if (!rankReductionMask.value().contains(idx) &&
3059  !srcType.isDynamicDim(rankReducedIdx)) {
3060  mixedSizes[idx] = getAsIndexOpFoldResult(
3061  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3062  size = srcType.getDimSize(rankReducedIdx++);
3063  }
3064  }
3065 
3066  // Pattern does not apply if the produced op would not verify.
3067  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3068  staticSizes, insertSliceOp.getStaticStrides()) !=
3069  SliceVerificationResult::Success)
3070  return failure();
3071  SliceBoundsVerificationResult sliceResult =
3072  verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3073  mixedSizes, insertSliceOp.getMixedStrides());
3074  if (!sliceResult.isValid)
3075  return failure();
3076 
3077  Operation *replacement =
3078  InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3079  insertSliceOp.getMixedOffsets(), mixedSizes,
3080  insertSliceOp.getMixedStrides());
3081 
3082  // In the parallel case there is no result and so nothing to cast.
3083  bool isParallelInsert =
3084  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3085  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3086  replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3087  insertSliceOp.getDestType(),
3088  replacement->getResult(0));
3089  }
3090  rewriter.replaceOp(insertSliceOp, replacement->getResults());
3091  return success();
3092  }
3093 };
3094 
3095 /// If additional static type information can be deduced from an insert_slice's
3096 /// size operands, insert an explicit cast of the op's source operand. This
3097 /// enables other canonicalization patterns that are matching for tensor_cast
3098 /// ops such as `ForOpTensorCastFolder` in SCF.
3099 ///
3100 /// Example:
3101 ///
3102 /// ```mlir
3103 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3104 /// : tensor<?x?xf32> into ...
3105 /// ```
3106 ///
3107 /// folds into:
3108 ///
3109 /// ```mlir
3110 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3111 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3112 /// : tensor<64x64xf32> into ...
3113 /// ```
3114 ///
3115 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3116 template <typename InsertOpTy>
3117 struct InsertSliceOpSourceCastInserter final
3118  : public OpRewritePattern<InsertOpTy> {
3119  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3120 
3121  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3122  PatternRewriter &rewriter) const override {
3123  RankedTensorType srcType = insertSliceOp.getSourceType();
3124  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3125  return failure();
3126  SmallVector<int64_t> newSrcShape(srcType.getShape());
3127  for (int64_t i = 0; i < srcType.getRank(); ++i) {
3128  if (std::optional<int64_t> constInt =
3129  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3130  // Bail on invalid IR.
3131  if (*constInt < 0)
3132  return failure();
3133  newSrcShape[i] = *constInt;
3134  }
3135  }
3136  if (!hasValidSizesOffsets(newSrcShape))
3137  return failure();
3138 
3139  RankedTensorType newSrcType = RankedTensorType::get(
3140  newSrcShape, srcType.getElementType(), srcType.getEncoding());
3141  if (srcType == newSrcType ||
3142  !preservesStaticInformation(srcType, newSrcType) ||
3143  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3144  return failure();
3145 
3146  // newSrcType is:
3147  // 1) Different from srcType.
3148  // 2) "More static" than srcType.
3149  // 3) Cast-compatible with srcType.
3150  // Insert the cast.
3151  OpBuilder::InsertionGuard g(rewriter);
3152  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3153  // that the insertion point is just before the InParallelOp in the
3154  // parallel case.
3155  if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3156  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3157  Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3158  newSrcType, insertSliceOp.getSource());
3159  rewriter.replaceOpWithNewOp<InsertOpTy>(
3160  insertSliceOp, cast, insertSliceOp.getDest(),
3161  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3162  insertSliceOp.getMixedStrides());
3163  return success();
3164  }
3165 };
3166 } // namespace
3167 
3168 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3169  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3170 }
3171 
3172 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3173  MLIRContext *context) {
3174  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3175  InsertSliceOpCastFolder<InsertSliceOp>,
3176  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3177 }
3178 
3179 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3180  Location loc,
3181  Value tensor,
3182  Value dest) {
3183  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3184  unsigned rank = rankedTensorType.getRank();
3185  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3186  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3187  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3188  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3189  sizes, strides);
3190 }
3191 
3192 //===----------------------------------------------------------------------===//
3193 // PadOp
3194 //===----------------------------------------------------------------------===//
3195 
3196 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3197  setNameFn(getResult(), "padded");
3198 }
3199 
3200 LogicalResult PadOp::verify() {
3201  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3202  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3203  auto expectedType =
3204  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3205  if (!expectedType) {
3206  return emitError("failed to infer expectedType from sourceType ")
3207  << sourceType << ", specified resultType is " << resultType;
3208  }
3209  if (resultType.getRank() != expectedType.getRank()) {
3210  return emitError("specified type ")
3211  << resultType << " does not match the inferred type "
3212  << expectedType;
3213  }
3214  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3215  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3216  continue;
3217  if (expectedType.isDynamicDim(i))
3218  continue;
3219  return emitError("specified type ")
3220  << resultType << " does not match the inferred type "
3221  << expectedType;
3222  }
3223 
3224  return success();
3225 }
3226 
3227 LogicalResult PadOp::verifyRegions() {
3228  auto &region = getRegion();
3229  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3230  Block &block = region.front();
3231  if (block.getNumArguments() != rank)
3232  return emitError("expected the block to have ") << rank << " arguments";
3233 
3234  // Note: the number and type of yield values are checked in the YieldOp.
3235  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3236  if (!en.value().isIndex())
3237  return emitOpError("expected block argument ")
3238  << (en.index() + 1) << " to be an index";
3239  }
3240 
3241  // Ensure that the region yields an element of the right type.
3242  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3243  if (yieldOp.getValue().getType() !=
3244  llvm::cast<ShapedType>(getType()).getElementType())
3245  return emitOpError("expected yield type to match shape element type");
3246 
3247  return success();
3248 }
3249 
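/// Worked example of the inference below (illustrative values, not from a
/// test): a source of type tensor<4x?xf32> with staticLow = [1, 0] and
/// staticHigh = [2, 3] infers tensor<7x?xf32> (4 + 1 + 2 = 7); a dimension
/// stays dynamic whenever the source size or either padding amount is dynamic.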
3250 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3251  ArrayRef<int64_t> staticLow,
3252  ArrayRef<int64_t> staticHigh,
3253  ArrayRef<int64_t> resultShape) {
3254  unsigned rank = sourceType.getRank();
3255  if (staticLow.size() != rank)
3256  return RankedTensorType();
3257  if (staticHigh.size() != rank)
3258  return RankedTensorType();
3259  if (!resultShape.empty() && resultShape.size() != rank)
3260  return RankedTensorType();
3261 
3262  SmallVector<int64_t, 4> inferredShape;
3263  for (auto i : llvm::seq<unsigned>(0, rank)) {
3264  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3265  staticHigh[i] == ShapedType::kDynamic) {
3266  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3267  : resultShape[i]);
3268  } else {
3269  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3270  assert((resultShape.empty() || size == resultShape[i] ||
3271  resultShape[i] == ShapedType::kDynamic) &&
3272  "mismatch between inferred shape and result shape");
3273  inferredShape.push_back(size);
3274  }
3275  }
3276 
3277  return RankedTensorType::get(inferredShape, sourceType.getElementType());
3278 }
3279 
3280 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3281  Value source, ArrayRef<int64_t> staticLow,
3282  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3283  bool nofold, ArrayRef<NamedAttribute> attrs) {
3284  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3285  if (!resultType)
3286  resultType = inferResultType(sourceType, staticLow, staticHigh);
3287  result.addAttributes(attrs);
3288  build(b, result, resultType, source, low, high,
3289  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3290  nofold ? b.getUnitAttr() : UnitAttr());
3291 }
3292 
3293 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3294  Value source, ValueRange low, ValueRange high, bool nofold,
3295  ArrayRef<NamedAttribute> attrs) {
3296  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3297  unsigned rank = sourceType.getRank();
3298  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3299  build(b, result, resultType, source, staticVector, staticVector, low, high,
3300  nofold, attrs);
3301 }
3302 
3303 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3304  Value source, ArrayRef<OpFoldResult> low,
3305  ArrayRef<OpFoldResult> high, bool nofold,
3306  ArrayRef<NamedAttribute> attrs) {
3307  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3308  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3309  SmallVector<int64_t, 4> staticLow, staticHigh;
3310  // staticLow and staticHigh have full information of the padding config.
3311  // For each dimension, this appends one value to staticLow and staticHigh.
3312  // If the padding amount is dynamic (i.e., not a constant), dynamicLow and
3313  // dynamicHigh grow by one value as well.
3314  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3315  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3316  if (!resultType) {
3317  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3318  }
3319  assert(llvm::isa<RankedTensorType>(resultType));
3320  result.addAttributes(attrs);
3321  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3322  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3323  nofold ? b.getUnitAttr() : UnitAttr());
3324 }
3325 
3326 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3327  Value source, ArrayRef<OpFoldResult> low,
3328  ArrayRef<OpFoldResult> high, Value constantPadValue,
3329  bool nofold, ArrayRef<NamedAttribute> attrs) {
3330  build(b, result, resultType, source, low, high, nofold, attrs);
3331 
3332  // Add a region and a block to yield the pad value.
3333  Region *region = result.regions[0].get();
3334  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3335  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3336  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3337 
3338  // `builder.createBlock` changes the insertion point within the block. Create
3339  // a guard to reset the insertion point of the builder after it is destroyed.
3340  OpBuilder::InsertionGuard guard(b);
3341  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3342  tensor::YieldOp::create(b, result.location, constantPadValue);
3343 }
3344 
3345 llvm::SmallBitVector PadOp::getPaddedDims() {
3346  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3347  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3348  for (const auto &en : enumerate(paddingWidths))
3349  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3350  paddedDims.set(en.index());
3351  };
3352  extractPaddedDims(getMixedLowPad());
3353  extractPaddedDims(getMixedHighPad());
3354  return paddedDims;
3355 }
3356 
3357 namespace {
3358 // Folds tensor.pad when padding is static zeros and the attribute
3359 // doesn't request otherwise.
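// For example (illustrative; %t and the shapes are made up):
//
//   %p = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<?x8xf32> to tensor<4x8xf32>
//
// is replaced by:
//
//   %p = tensor.cast %t : tensor<?x8xf32> to tensor<4x8xf32>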
3360 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3361  using OpRewritePattern<PadOp>::OpRewritePattern;
3362 
3363  LogicalResult matchAndRewrite(PadOp padTensorOp,
3364  PatternRewriter &rewriter) const override {
3365  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3366  return failure();
3367  if (padTensorOp.getNofold())
3368  return failure();
3369  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3370  padTensorOp, padTensorOp.getResult().getType(),
3371  padTensorOp.getSource());
3372  return success();
3373  }
3374 };
3375 
3376 // Fold CastOp into PadOp when adding static information.
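// For example (illustrative; %t and the shapes are made up):
//
//   %c = tensor.cast %t : tensor<8x16xf32> to tensor<?x?xf32>
//   %p = tensor.pad %c low[0, 0] high[1, 1] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
//
// becomes a pad of %t with the inferred static type, cast back to the
// original result type:
//
//   %0 = tensor.pad %t low[0, 0] high[1, 1] { ... }
//       : tensor<8x16xf32> to tensor<9x17xf32>
//   %p = tensor.cast %0 : tensor<9x17xf32> to tensor<?x?xf32>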
3377 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3378  using OpRewritePattern<PadOp>::OpRewritePattern;
3379 
3380  LogicalResult matchAndRewrite(PadOp padTensorOp,
3381  PatternRewriter &rewriter) const override {
3382  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3383  if (!tensor::canFoldIntoConsumerOp(castOp))
3384  return failure();
3385 
3386  auto newResultType = PadOp::inferResultType(
3387  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3388  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3389  padTensorOp.getResultType().getShape());
3390 
3391  if (newResultType == padTensorOp.getResultType()) {
3392  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3393  padTensorOp.getSourceMutable().assign(castOp.getSource());
3394  });
3395  } else {
3396  auto newOp = PadOp::create(
3397  rewriter, padTensorOp->getLoc(), newResultType,
3398  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3399  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3400  padTensorOp.getHigh(), padTensorOp.getNofold(),
3401  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3402  IRMapping mapper;
3403  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3404 
3405  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3406  padTensorOp, padTensorOp.getResultType(), newOp);
3407  }
3408  return success();
3409  }
3410 };
3411 
3412 // Fold CastOp using the result of PadOp back into the latter if it adds
3413 // static information.
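// For example (illustrative; %t, %h, and the shapes are made up):
//
//   %p = tensor.pad %t low[0, 0] high[%h, 0] { ... }
//       : tensor<4x8xf32> to tensor<?x8xf32>
//   %c = tensor.cast %p : tensor<?x8xf32> to tensor<6x8xf32>
//
// becomes a single pad that directly produces the more static type:
//
//   %c = tensor.pad %t low[0, 0] high[%h, 0] { ... }
//       : tensor<4x8xf32> to tensor<6x8xf32>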
3414 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3415  using OpRewritePattern<PadOp>::OpRewritePattern;
3416 
3417  LogicalResult matchAndRewrite(PadOp padTensorOp,
3418  PatternRewriter &rewriter) const override {
3419  if (!padTensorOp.getResult().hasOneUse())
3420  return failure();
3421  auto tensorCastOp =
3422  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3423  if (!tensorCastOp)
3424  return failure();
3425  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3426  tensorCastOp.getDest().getType()))
3427  return failure();
3428 
3429  auto replacementOp = PadOp::create(
3430  rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3431  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3432  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3433  padTensorOp.getHigh(), padTensorOp.getNofold(),
3434  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3435  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3436 
3437  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3438  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3439  return success();
3440  }
3441 };
3442 
3443 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3444 /// different dimensions. The pattern applies if the following preconditions
3445 /// hold:
3446 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3447 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3448 /// 3) the tensor::PadOps perform only high-padding,
3449 /// 4) the tensor::PadOps have the same constant padding value,
3450 /// 5) the tensor::PadOps do not have common padding dimensions,
3451 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3452 /// zero-offset for every dimension.
3453 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3454 /// the
3455 /// padded source dimensions.
3456 ///
3457 /// Example:
3458 ///
3459 /// ```mlir
3460 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3461 /// : tensor<64x64xf32> to tensor<?x64xf32>
3462 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3463 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3464 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3465 /// : tensor<8x64xf32> to tensor<8x?xf32>
3466 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3467 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3468 /// ```
3469 ///
3470 /// folds into:
3471 ///
3472 /// ```mlir
3473 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3474 /// : tensor<64x64xf32> to tensor<?x?xf32>
3475 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3476 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3477 /// ```
3478 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3479  using OpRewritePattern<PadOp>::OpRewritePattern;
3480 
3481  LogicalResult matchAndRewrite(PadOp padOp,
3482  PatternRewriter &rewriter) const override {
3483  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3484  if (!innerSliceOp)
3485  return failure();
3486  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3487  if (!outerPadOp || outerPadOp.getNofold())
3488  return failure();
3489  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3490  if (!outerSliceOp)
3491  return failure();
3492 
3493  // 1) Fail if the chain is rank-reducing.
3494  int64_t rank = padOp.getSourceType().getRank();
3495  if (outerSliceOp.getSourceType().getRank() != rank) {
3496  return rewriter.notifyMatchFailure(padOp,
3497  "cannot fold rank-reducing chain");
3498  }
3499 
3500  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3501  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3502  return rewriter.notifyMatchFailure(
3503  padOp, "cannot fold non-unit stride ExtractSliceOps");
3504  }
3505 
3506  // 3) Fail if the tensor::PadOps have non-zero low padding.
3507  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3508  return rewriter.notifyMatchFailure(padOp,
3509  "cannot fold PadOps with low padding");
3510  }
3511 
3512  // 4) Fail if the tensor::PadOps padding values do not match.
3513  Attribute innerAttr, outerAttr;
3514  Value innerValue = padOp.getConstantPaddingValue();
3515  Value outerValue = outerPadOp.getConstantPaddingValue();
3516  if (!innerValue || !outerValue ||
3517  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3518  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3519  innerAttr != outerAttr) {
3520  return rewriter.notifyMatchFailure(
3521  padOp, "cannot fold PadOps with different padding values");
3522  }
3523 
3524  // 5) Fail if a dimension is padded by both tensor::PadOps.
3525  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3526  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3527  if (innerDims.anyCommon(outerDims)) {
3528  return rewriter.notifyMatchFailure(
3529  padOp, "cannot fold PadOps with common padding dimensions");
3530  }
3531 
3532  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3533  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3534  // for every dimension, and use the offset of the other pair. Fail if no
3535  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3536  // exists.
3537  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3538  for (auto en : enumerate(newOffsets)) {
3539  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3540  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3541  if (!innerDims.test(en.index()) &&
3542  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3543  en.value() = outerOffset;
3544  continue;
3545  }
3546  if (!outerDims.test(en.index()) &&
3547  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3548  en.value() = innerOffset;
3549  continue;
3550  }
3551  return rewriter.notifyMatchFailure(
3552  padOp, "cannot find zero-offset and zero-padding pair");
3553  }
3554 
3555  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3556  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3557  // outer tensor::PadOp and fail if the size of the inner
3558  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3559  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3560  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3561  for (auto en : enumerate(newSizes)) {
3562  if (!outerDims.test(en.index()))
3563  continue;
3564  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3565  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3566  assert(ShapedType::isStatic(sourceSize) &&
3567  "expected padded dimension to have a static size");
3568  if (getConstantIntValue(sliceSize) != sourceSize) {
3569  return rewriter.notifyMatchFailure(
3570  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3571  "match the size of the outer padding");
3572  }
3573  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3574  }
3575 
3576  // Combine the high paddings of the two tensor::PadOps.
3577  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3578  for (auto en : enumerate(newHighPad)) {
3579  if (innerDims.test(en.index()))
3580  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3581  if (outerDims.test(en.index()))
3582  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3583  }
3584 
3585  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3586  // the two paddings in one step.
3587  auto newSliceOp = ExtractSliceOp::create(
3588  rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3589  newSizes, innerSliceOp.getMixedStrides());
3590  auto newPadOp = PadOp::create(
3591  rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3592  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3593  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3594  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3595  newPadOp.getRegion().begin());
3596  rewriter.replaceOp(padOp, newPadOp.getResult());
3597  return success();
3598  }
3599 };
3600 
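// Folds constant values of the dynamic low/high padding operands into the
// static padding attributes and recomputes a more static result type, casting
// back to the original type. A sketch (illustrative; %t, %c2, and the shapes
// are made up):
//
//   %c2 = arith.constant 2 : index
//   %p = tensor.pad %t low[0, %c2] high[%c2, 0] { ... }
//       : tensor<2x4xf32> to tensor<?x?xf32>
//
// becomes:
//
//   %0 = tensor.pad %t low[0, 2] high[2, 0] { ... }
//       : tensor<2x4xf32> to tensor<4x6xf32>
//   %p = tensor.cast %0 : tensor<4x6xf32> to tensor<?x?xf32>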
3601 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3602  using OpRewritePattern<PadOp>::OpRewritePattern;
3603 
3604  LogicalResult matchAndRewrite(PadOp padTensorOp,
3605  PatternRewriter &rewriter) const override {
3606  Value input = padTensorOp.getSource();
3607  if (!llvm::isa<RankedTensorType>(input.getType()))
3608  return failure();
3609  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3610  auto inputRank = inputDims.size();
3611 
3612  auto oldResultType =
3613  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3614  if (!oldResultType)
3615  return failure();
3616 
3617  auto outputDims = oldResultType.getShape();
3618 
3619  // Extract the static info from the high and low operands.
3620  SmallVector<int64_t> constOperandsLow;
3621  SmallVector<Value> newLows;
3622  for (auto operand : padTensorOp.getLow()) {
3623  APSInt intOp;
3624  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3625  constOperandsLow.push_back(ShapedType::kDynamic);
3626  newLows.push_back(operand);
3627  continue;
3628  }
3629  constOperandsLow.push_back(intOp.getExtValue());
3630  }
3631  SmallVector<int64_t> constOperandsHigh;
3632  SmallVector<Value> newHighs;
3633  for (auto operand : padTensorOp.getHigh()) {
3634  APSInt intOp;
3635  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3636  constOperandsHigh.push_back(ShapedType::kDynamic);
3637  newHighs.push_back(operand);
3638  continue;
3639  }
3640  constOperandsHigh.push_back(intOp.getExtValue());
3641  }
3642 
3643  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3644  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3645 
3646  // Verify the op is well-formed.
3647  if (inputDims.size() != outputDims.size() ||
3648  inputDims.size() != constLow.size() ||
3649  inputDims.size() != constHigh.size())
3650  return failure();
3651 
3652  auto lowCount = 0;
3653  auto highCount = 0;
3654  for (size_t i = 0; i < inputRank; i++) {
3655  if (constLow[i] == ShapedType::kDynamic)
3656  constLow[i] = constOperandsLow[lowCount++];
3657  if (constHigh[i] == ShapedType::kDynamic)
3658  constHigh[i] = constOperandsHigh[highCount++];
3659  }
3660 
3661  auto staticLow = ArrayRef<int64_t>(constLow);
3662  auto staticHigh = ArrayRef<int64_t>(constHigh);
3663 
3664  // Calculate the output sizes with the static information.
3665  SmallVector<int64_t> newOutDims;
3666  for (size_t i = 0; i < inputRank; i++) {
3667  if (outputDims[i] == ShapedType::kDynamic) {
3668  newOutDims.push_back(
3669  (staticLow[i] == ShapedType::kDynamic ||
3670  staticHigh[i] == ShapedType::kDynamic ||
3671  inputDims[i] == ShapedType::kDynamic
3672  ? ShapedType::kDynamic
3673  : inputDims[i] + staticLow[i] + staticHigh[i]));
3674  } else {
3675  newOutDims.push_back(outputDims[i]);
3676  }
3677  }
3678 
3679  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3680  llvm::all_of(newOutDims,
3681  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3682  return failure();
3683 
3684  // Rewrite the op using the new static type.
3685  auto newResultType = RankedTensorType::get(
3686  newOutDims, padTensorOp.getType().getElementType());
3687  auto newOp = PadOp::create(
3688  rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3689  staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3690  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3691 
3692  IRMapping mapper;
3693  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3694  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3695  newOp);
3696 
3697  return success();
3698  }
3699 };
3700 
3701 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3702 ///
3703 /// Example:
3704 ///
3705 /// ```mlir
3706 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3707 /// tensor.yield %val
3708 /// } : tensor<2x2xf32> to tensor<2x5xf32>
3709 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3710 /// tensor.yield %val
3711 /// } : tensor<2x5xf32> to tensor<5x7xf32>
3712 /// ```
3713 ///
3714 /// folds into:
3715 ///
3716 /// ```mlir
3717 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3718 /// tensor.yield %val
3719 /// } : tensor<2x2xf32> to tensor<5x7xf32>
3720 /// ```
3721 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3722  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3723 
3724  LogicalResult matchAndRewrite(tensor::PadOp padOp,
3725  PatternRewriter &rewriter) const override {
3726  if (padOp.getNofold()) {
3727  return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3728  }
3729 
3730  auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3731  if (!producerPad || producerPad.getNofold()) {
3732  return rewriter.notifyMatchFailure(
3733  padOp, "producer is not a foldable tensor.pad op");
3734  }
3735 
3736  // Fail if the tensor::PadOps padding values do not match.
3737  Value consumerPadValue = padOp.getConstantPaddingValue();
3738  Value producerPadValue = producerPad.getConstantPaddingValue();
3739  if (!consumerPadValue || !producerPadValue ||
3740  consumerPadValue != producerPadValue) {
3741  return rewriter.notifyMatchFailure(
3742  padOp,
3743  "cannot fold PadOps with different or non-constant padding values");
3744  }
3745 
3746  Location loc = padOp.getLoc();
3747  AffineExpr d0, d1;
3748  bindDims(rewriter.getContext(), d0, d1);
3749 
3750  // Combine the low/high paddings of the two tensor::PadOps.
3751  auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3752  ArrayRef<OpFoldResult> producerPaddings) {
3753  SmallVector<OpFoldResult> sumPaddings;
3754  for (auto [consumerIndex, producerIndex] :
3755  llvm::zip_equal(consumerPaddings, producerPaddings)) {
3756  sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3757  rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3758  }
3759  return sumPaddings;
3760  };
3761 
3762  SmallVector<OpFoldResult> newHighPad =
3763  addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3764  SmallVector<OpFoldResult> newLowPad =
3765  addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3766 
3767  auto newPadOp = tensor::PadOp::create(
3768  rewriter, padOp.getLoc(), padOp.getResultType(),
3769  producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3770  getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3771  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3772  newPadOp.getRegion().begin());
3773  rewriter.replaceOp(padOp, newPadOp.getResult());
3774  return success();
3775  }
3776 };
3777 
3778 } // namespace
3779 
3780 LogicalResult
3781 PadOp::reifyResultShapes(OpBuilder &b,
3782  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3783  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3784  SmallVector<OpFoldResult> lp = getMixedLowPad();
3785  SmallVector<OpFoldResult> hp = getMixedHighPad();
3786  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3787  if (!getType().isDynamicDim(i)) {
3788  reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
3789  continue;
3790  }
3791  Location loc = getLoc();
3792  Value dim = b.createOrFold<tensor::DimOp>(
3793  loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
3794 
3795  AffineExpr d0, d1, d2;
3796  bindDims(b.getContext(), d0, d1, d2);
3797  reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
3798  b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3799  }
3800  return success();
3801 }
3802 
3803 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3804  MLIRContext *context) {
3805  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3806  FoldOrthogonalPaddings, FoldStaticPadding,
3807  FoldConsecutiveConstantPadding>(context);
3808 }
3809 
3810 /// Return the padding value of the PadOp if it is constant. In this context,
3811 /// "constant" means an actual constant or "defined outside of the block".
3812 ///
3813 /// Values are considered constant in three cases:
3814 /// - A ConstantLike value.
3815 /// - A basic block argument from a different block.
3816 /// - A value defined outside of the block.
3817 ///
3818 /// If the padding value is not constant, an empty Value is returned.
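/// For example (illustrative; %cst, %t, and the shapes are made up),
/// getConstantPaddingValue() returns %cst here:
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %t low[1, 1] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<2x2xf32> to tensor<4x4xf32>
/// ```
///
/// If the yielded value were instead computed inside the pad block (e.g. from
/// %i and %j), an empty Value would be returned.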
3819 Value PadOp::getConstantPaddingValue() {
3820  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3821  if (!yieldOp)
3822  return {};
3823  Value padValue = yieldOp.getValue();
3824  // Check if yield value is a constant.
3825  if (matchPattern(padValue, m_Constant()))
3826  return padValue;
3827  // Check if yield value is defined inside the PadOp block.
3828  if (padValue.getParentBlock() == &getRegion().front())
3829  return {};
3830  // Else: Yield value defined outside of the PadOp block.
3831  return padValue;
3832 }
3833 
3834 OpFoldResult PadOp::fold(FoldAdaptor) {
3835  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3836  !getNofold())
3837  return getSource();
3838  return {};
3839 }
3840 
3841 //===----------------------------------------------------------------------===//
3842 // ParallelInsertSliceOp
3843 //===----------------------------------------------------------------------===//
3844 
3845 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3846  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3847  for (const auto &it :
3848  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3849  Operation &nextOp = it.value();
3850  if (&nextOp == getOperation())
3851  return parallelCombiningParent.getParentResult(it.index());
3852  }
3853  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3854 }
3855 
3856 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3857 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3858  Value source, Value dest,
3859  ArrayRef<OpFoldResult> offsets,
3860  ArrayRef<OpFoldResult> sizes,
3861  ArrayRef<OpFoldResult> strides,
3862  ArrayRef<NamedAttribute> attrs) {
3863  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3864  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3865  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3866  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3867  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3868  result.addAttributes(attrs);
3869  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3870  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3871  b.getDenseI64ArrayAttr(staticSizes),
3872  b.getDenseI64ArrayAttr(staticStrides));
3873 }
3874 
3875 /// Build an ParallelInsertSliceOp with mixed static and dynamic entries
3876 /// packed into a Range vector.
3877 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3878  Value source, Value dest,
3879  ArrayRef<Range> ranges,
3880  ArrayRef<NamedAttribute> attrs) {
3881  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3882  build(b, result, source, dest, offsets, sizes, strides, attrs);
3883 }
3884 
3885 // Build a ParallelInsertSliceOp with dynamic entries.
3886 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3887  Value source, Value dest, ValueRange offsets,
3888  ValueRange sizes, ValueRange strides,
3889  ArrayRef<NamedAttribute> attrs) {
3890  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3891  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3892  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3893  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3894  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3895  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3896  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3897 }
3898 
3899 LogicalResult ParallelInsertSliceOp::verify() {
3900  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3901  return this->emitError("expected InParallelOpInterface parent, got: ")
3902  << *(getOperation()->getParentOp());
3903 
3904  // Verify result type against inferred type.
3905  RankedTensorType expectedType;
3906  SliceVerificationResult result =
3907  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3908  getStaticSizes(), getStaticStrides(), &expectedType);
3909  if (result != SliceVerificationResult::Success)
3910  return produceSliceErrorMsg(result, *this, expectedType);
3911 
3912  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3913  // to the destination tensor.
3914  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3915  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3916  getStaticStrides(), /*generateErrorMessage=*/true);
3917  if (!boundsResult.isValid)
3918  return getOperation()->emitError(boundsResult.errorMessage);
3919 
3920  return success();
3921 }
3922 
3923 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3924  RewritePatternSet &results, MLIRContext *context) {
3925  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3926  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3927  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3928 }
3929 
3930 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3931  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3932 }
3933 
3934 // ParallelCombiningOpInterface implementation.
3935 MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3936  return getDestMutable();
3937 }
3938 
3939 Operation *ParallelInsertSliceOp::getIteratingParent() {
3940  // Return the parent InParallelOpInterface's parent.
3941  if (auto combiningOp =
3942  dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
3943  return combiningOp->getParentOp();
3944  return nullptr;
3945 }
3946 
3947 //===----------------------------------------------------------------------===//
3948 // ScatterOp
3949 //===----------------------------------------------------------------------===//
3950 
3951 void ScatterOp::getAsmResultNames(
3952  function_ref<void(Value, StringRef)> setNameFn) {
3953  setNameFn(getResult(), "scatter");
3954 }
3955 
3956 LogicalResult ScatterOp::verify() {
3957  int64_t destRank = getDestType().getRank();
3958  ArrayRef<int64_t> scatterDims = getScatterDims();
3959  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3960  getIndicesType().getShape(), destRank,
3961  "scatter", "dest")))
3962  return failure();
3963 
3964  if (!getUnique())
3965  return emitOpError("requires 'unique' attribute to be set");
3966  // TODO: we could also check statically that there are fewer leading index
3967  // tensor dims than the dest dims. If this is not the case, the unique
3968  // attribute cannot be true.
3969 
3970  // Use the GatherOp::inferResultType on the `dest` type and verify the
3971  // expected type matches the source type.
3972  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3973  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3974  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3975  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3976  if (getSourceType() != expectedSourceType &&
3977  getSourceType() != expectedRankReducedSourceType) {
3978  return emitOpError("source type mismatch: expected ")
3981  << expectedSourceType << " or its rank-reduced variant "
3982  << expectedRankReducedSourceType << " (got: " << getSourceType()
3983  << ")";
3984  }
3985 
3986  return success();
3987 }
3988 
3989 //===----------------------------------------------------------------------===//
3990 // SplatOp
3991 //===----------------------------------------------------------------------===//
3992 
3993 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3994  Type aggregateType, ValueRange dynamicSizes) {
3995  build(builder, result, aggregateType, element, dynamicSizes);
3996 }
3997 
3998 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3999  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4000  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
4001  build(builder, result, aggregateType, element, dynamicSizes);
4002 }
4003 
4004 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4005  ArrayRef<OpFoldResult> sizes) {
4006  SmallVector<int64_t> staticShape;
4007  SmallVector<Value> dynamicSizes;
4008  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
4009  build(builder, result, element, staticShape, dynamicSizes);
4010 }
4011 
4012 void SplatOp::getAsmResultNames(
4013  function_ref<void(Value, StringRef)> setNameFn) {
4014  setNameFn(getResult(), "splat");
4015 }
4016 
4017 LogicalResult SplatOp::verify() {
4018  if (getType().getNumDynamicDims() != getDynamicSizes().size())
4019  return emitOpError("incorrect number of dynamic sizes, has ")
4020  << getDynamicSizes().size() << ", expected "
4021  << getType().getNumDynamicDims();
4022  return success();
4023 }
4024 
4025 LogicalResult
4026 SplatOp::reifyResultShapes(OpBuilder &builder,
4027  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4028  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4029  unsigned ctr = 0;
4030  for (int64_t i = 0; i < getType().getRank(); ++i) {
4031  if (getType().isDynamicDim(i)) {
4032  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4033  } else {
4034  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4035  }
4036  }
4037  return success();
4038 }
4039 
4040 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4041  auto constOperand = adaptor.getInput();
4042  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4043  return {};
4044 
4045  // Do not fold if the splat is not statically shaped
4046  if (!getType().hasStaticShape())
4047  return {};
4048 
4049  // SplatElementsAttr::get treats single value for second arg as being a
4050  // splat.
4051  return SplatElementsAttr::get(getType(), {constOperand});
4052 }
4053 
4054 //===----------------------------------------------------------------------===//
4055 // Common Canonicalizers and Folders.
4056 //===----------------------------------------------------------------------===//
4057 static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4058  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4059  // 2. Exclude DPS ops that are also LoopLike from this interface as they
4060  // might need special handling of attached regions.
4061  if (isa<InsertSliceOp>(op.getOperation()) ||
4062  isa<LoopLikeOpInterface>(op.getOperation()))
4063  return false;
4064 
4065  return hasFoldableTensorCastOperand(op);
4066 }
4067 
4068 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4069 /// the `tensor.cast` has source that is more static than the consuming op.
4070 ///
4071 /// Example:
4072 /// ```mlir
4073 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4074 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4075 /// ```
4076 ///
4077 /// folds into:
4078 ///
4079 /// ```mlir
4080 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4081 /// ```
4082 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4083 /// can add the pattern to their canonicalizers.
4084 struct FoldTensorCastProducerOp
4085  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4086  using OpInterfaceRewritePattern<
4087  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4088 
4089  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4090  PatternRewriter &rewriter) const override {
4091 
4092  // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4093  // for that instead.
4094  if (!foldTensorCastPrecondition(op) ||
4095  isa<linalg::RelayoutOpInterface>(*op))
4096  return failure();
4097 
4098  SmallVector<Type> newResultTypes(op->getResultTypes());
4099  SmallVector<Value> newOperands =
4100  getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4101 
4102  // Clone op
4103  auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4104 
4105  SmallVector<Value, 4> replacements;
4106  replacements.reserve(newOp->getNumResults());
4107  for (auto [oldResult, newResult] :
4108  llvm::zip(op->getResults(), newOp->getResults())) {
4109  if (newResult.getType() != oldResult.getType()) {
4110  replacements.push_back(tensor::CastOp::create(
4111  rewriter, op->getLoc(), oldResult.getType(), newResult));
4112  } else {
4113  replacements.push_back(newResult);
4114  }
4115  }
4116  rewriter.replaceOp(op, replacements);
4117 
4118  return success();
4119  }
4120 };
4121 
4122 //===----------------------------------------------------------------------===//
4123 // TensorDialect
4124 //===----------------------------------------------------------------------===//
4125 
4126 void TensorDialect::getCanonicalizationPatterns(
4127  RewritePatternSet &results) const {
4128  results.add<FoldTensorCastProducerOp>(getContext());
4129 }
4130 
4131 //===----------------------------------------------------------------------===//
4132 // TableGen'd op method definitions
4133 //===----------------------------------------------------------------------===//
4134 
4135 #define GET_OP_CLASSES
4136 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
UnitAttr getUnitAttr()
Definition: Builders.cpp:98
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:167
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:368
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:364
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:378
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:230
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:241
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
bool hasRank() const
Returns whether this type is ranked, i.e., whether it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1469
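A hedged usage sketch (the helper addSizes is hypothetical, not part of this file): composing d0 + d1 over two OpFoldResult sizes so the result remains an attribute when both inputs are static:

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// Add two mixed static/dynamic sizes; folds to an index attribute if possible.
static mlir::OpFoldResult addSizes(mlir::OpBuilder &b, mlir::Location loc,
                                   mlir::OpFoldResult lhs,
                                   mlir::OpFoldResult rhs) {
  mlir::AffineExpr d0, d1;
  mlir::bindDims(b.getContext(), d0, d1);
  auto map = mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
  return mlir::affine::makeComposedFoldedAffineApply(b, loc, map, {lhs, rhs});
}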
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:387
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:356
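A minimal sketch (the free function bypassFoldableCasts is an assumption made for illustration) of how these two helpers pair up when folding producer tensor.cast ops into a consumer:

#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Rewrite the operands of `op` in place so that foldable tensor.cast
// producers are bypassed; fails if there is nothing to fold.
static mlir::LogicalResult bypassFoldableCasts(mlir::Operation *op) {
  if (!mlir::tensor::hasFoldableTensorCastOperand(op))
    return mlir::failure();
  return mlir::tensor::foldTensorCast(op);
}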
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2687
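A hedged sketch of wiring these patterns into a pass; the element-count threshold and the function name are assumptions made for illustration:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Fold extract_slice of constants, but only when the resulting constant
// stays small, to avoid blowing up the IR with large constants.
static mlir::LogicalResult foldSmallConstantSlices(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::tensor::populateFoldConstantExtractSlicePatterns(
      patterns, [](mlir::tensor::ExtractSliceOp op) {
        auto type = llvm::cast<mlir::ShapedType>(op.getType());
        return type.hasStaticShape() && type.getNumElements() < 1024;
      });
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}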
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:349
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:365
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:318
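A hedged sketch of the check that rewrite patterns typically perform before looking through a producing tensor.cast; skipFoldableCast is a hypothetical helper:

#include "mlir/Dialect/Tensor/IR/Tensor.h"

// If `operand` comes from a cast towards a more dynamic type, use the cast's
// source instead; the consumer then sees the more static type directly.
static mlir::Value skipFoldableCast(mlir::Value operand) {
  if (auto castOp = operand.getDefiningOp<mlir::tensor::CastOp>())
    if (mlir::tensor::canFoldIntoConsumerOp(castOp))
      return castOp.getSource();
  return operand;
}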
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp at offsets [0 .. 0] with unit strides.
Definition: TensorOps.cpp:3179
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp at offsets [0 .. 0] with unit strides.
Definition: TensorOps.cpp:2774
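A minimal sketch assuming `padded` has type tensor<1x64x1xf32>; the round trip drops the unit dimensions, leaves room for rank-reduced processing, and re-inserts the result (function name hypothetical):

#include "mlir/Dialect/Tensor/IR/Tensor.h"

static mlir::Value roundTripRankReduction(mlir::OpBuilder &b,
                                          mlir::Location loc,
                                          mlir::Value padded) {
  // Rank-reduced view of the data: tensor<64xf32>.
  auto reducedType = mlir::RankedTensorType::get({64}, b.getF32Type());
  mlir::Value reduced = mlir::tensor::createCanonicalRankReducingExtractSliceOp(
      b, loc, padded, reducedType);
  // ... rank-1 processing of `reduced` would go here ...
  return mlir::tensor::createCanonicalRankReducingInsertSliceOp(b, loc, reduced,
                                                                padded);
}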
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:124
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the size of dimension dim of the given tensor value (an attribute if static, a Value otherwise).
Definition: TensorOps.cpp:57
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
Definition: TensorOps.cpp:1439
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
Helper for DestinationStyleOpInterface: return the destination operand tied to opResult if the defining op is destination-style, otherwise create a matching tensor.empty destination.
Definition: TensorOps.cpp:75
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:167
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:266
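An illustrative sketch (assumed types, not code from this file): a target type preserves static information iff no dimension that is static in the source becomes dynamic in the target:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"

static void staticInfoExample(mlir::MLIRContext *ctx) {
  auto f32 = mlir::Float32Type::get(ctx);
  auto src = mlir::RankedTensorType::get({8, mlir::ShapedType::kDynamic}, f32);
  auto moreStatic = mlir::RankedTensorType::get({8, 16}, f32);
  auto lessStatic = mlir::RankedTensorType::get(
      {mlir::ShapedType::kDynamic, mlir::ShapedType::kDynamic}, f32);
  // true: every static dim of `src` stays static (dim 0 == 8).
  (void)mlir::tensor::preservesStaticInformation(src, moreStatic);
  // false: the static dim 0 of `src` becomes dynamic.
  (void)mlir::tensor::preservesStaticInformation(src, lessStatic);
}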
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:66
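A minimal sketch (hypothetical helper) that uses the mixed sizes of a ranked tensor value to build a matching tensor.empty destination:

#include "mlir/Dialect/Tensor/IR/Tensor.h"

static mlir::Value createMatchingEmpty(mlir::OpBuilder &b, mlir::Location loc,
                                       mlir::Value source) {
  auto type = llvm::cast<mlir::RankedTensorType>(source.getType());
  // Static dims become index attributes, dynamic dims become tensor.dim values.
  mlir::SmallVector<mlir::OpFoldResult> sizes =
      mlir::tensor::getMixedSizes(b, loc, source);
  return mlir::tensor::EmptyOp::create(b, loc, sizes, type.getElementType());
}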
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
Helper for DestinationStyleOpInterface: collect (or create) one destination value for every tensor result of op.
Definition: TensorOps.cpp:110
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:356
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions [0 .. N), one per expression.
Definition: AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:23
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:90
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
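A hedged sketch of the static/dynamic round trip these utilities support for view-like ops (the function and parameter names are assumptions made for illustration):

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include <cassert>

static void roundTripMixedValues(mlir::MLIRContext *ctx,
                                 llvm::ArrayRef<int64_t> staticOffsets,
                                 mlir::ValueRange dynamicOffsets) {
  // One OpFoldResult per offset: an index attribute when static, otherwise
  // the next dynamic Value.
  mlir::SmallVector<mlir::OpFoldResult> mixed =
      mlir::getMixedValues(staticOffsets, dynamicOffsets, ctx);
  // Split back into the (static ints, dynamic values) representation.
  auto [staticVec, dynamicVec] = mlir::decomposeMixedValues(mixed);
  assert(staticVec.size() == staticOffsets.size() && "rank must be preserved");
  (void)dynamicVec;
}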
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
Definition: TensorOps.cpp:4085
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:4089
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2706
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2707
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2694
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2695
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:333
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
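A hedged usage sketch (the helper name and surrounding verifier context are assumptions): running the bounds check with message generation enabled and reporting the result on the op:

#include "mlir/IR/Operation.h"

// Assumes the header declaring verifyInBoundsSlice and
// SliceBoundsVerificationResult (MLIR's view-like op utilities) is included.
static mlir::LogicalResult checkStaticSlice(mlir::Operation *op,
                                            llvm::ArrayRef<int64_t> shape,
                                            llvm::ArrayRef<int64_t> offsets,
                                            llvm::ArrayRef<int64_t> sizes,
                                            llvm::ArrayRef<int64_t> strides) {
  mlir::SliceBoundsVerificationResult result = mlir::verifyInBoundsSlice(
      shape, offsets, sizes, strides, /*generateErrorMessage=*/true);
  if (!result.isValid)
    return op->emitOpError(result.errorMessage);
  return mlir::success();
}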