TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/TypeUtilities.h"
30 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/DenseSet.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallBitVector.h"
34 #include "llvm/ADT/StringRef.h"
35 #include "llvm/Support/MathExtras.h"
36 #include <algorithm>
37 #include <optional>
38 
39 using namespace mlir;
40 using namespace mlir::tensor;
41 
42 using llvm::divideCeilSigned;
43 using llvm::divideFloorSigned;
44 using llvm::mod;
45 
46 /// Materialize a single constant operation from a given attribute value with
47 /// the desired resultant type.
48 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
49  Attribute value, Type type,
50  Location loc) {
51  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
52  return op;
53  if (complex::ConstantOp::isBuildableWith(value, type))
54  return builder.create<complex::ConstantOp>(loc, type,
55  llvm::cast<ArrayAttr>(value));
56  return nullptr;
57 }
58 
59 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
60  int64_t dim) {
61  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
63  if (tensorType.isDynamicDim(dim))
64  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
65 
66  return builder.getIndexAttr(tensorType.getDimSize(dim));
67 }
68 
69 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
70  Location loc, Value value) {
71  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
72  SmallVector<OpFoldResult> result;
73  for (int64_t i = 0; i < tensorType.getRank(); ++i)
74  result.push_back(getMixedSize(builder, loc, value, i));
75  return result;
76 }
77 
78 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
79  OpResult opResult) {
80  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
81  assert(tensorType && "expected tensor type");
82 
83  // If the op has a destination, it implements DestinationStyleOpInterface and
84  // we can query the destination operand from that interface.
85  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
86  if (destOp)
87  return destOp.getTiedOpOperand(opResult)->get();
88 
89  // Otherwise, create a new destination tensor with the same shape.
90  OpBuilder::InsertionGuard g(b);
91  b.setInsertionPoint(opResult.getDefiningOp());
92 
93  // Compute sizes.
94  SmallVector<OpFoldResult> mixedSizes;
95  if (!tensorType.hasStaticShape()) {
96  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
97  ReifiedRankedShapedTypeDims reifiedShapes;
98  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
99  return failure();
100  mixedSizes = reifiedShapes[opResult.getResultNumber()];
101  } else {
102  // Static shape: Take static sizes directly.
103  for (int64_t sz : tensorType.getShape())
104  mixedSizes.push_back(b.getIndexAttr(sz));
105  }
106 
107  // Create empty tensor.
108  Value emptyTensor =
109  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
110  return emptyTensor;
111 }
112 
113 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
114  Operation *op,
115  SmallVector<Value> &result) {
116  for (OpResult opResult : op->getResults()) {
117  if (llvm::isa<TensorType>(opResult.getType())) {
118  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
119  if (failed(destination))
120  return failure();
121  result.push_back(*destination);
122  }
123  }
124  return success();
125 }
126 
127 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
128  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
129  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
130  return rtp1.getShape() == rtp2.getShape() &&
131  rtp1.getElementType() == rtp2.getElementType();
132  return false;
133  }
134  return tp1 == tp2; // default implementation
135 }
136 
137 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
138 /// rank-extending tensor.insert_slice op.
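/// For example (illustrative note, not from the original source): an
/// extract_slice with mixed sizes [10, 1, 8] whose result has the reduced
/// shape [10, 8] drops only the middle unit dimension, so the returned bit
/// vector is {0, 1, 0}.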
139 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
140  ArrayRef<OpFoldResult> mixedSizes) {
141  llvm::SmallBitVector droppedDims(mixedSizes.size());
142  int64_t shapePos = reducedShape.size() - 1;
143 
144  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
145  size_t idx = mixedSizes.size() - size.index() - 1;
146  // Rank-reduced dims must have a static unit dimension.
147  bool isStaticUnitSize =
148  isa<Attribute>(size.value()) &&
149  llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
150 
151  if (shapePos < 0) {
152  // There are no more dims in the reduced shape. All remaining sizes must
153  // be rank-reduced dims.
154  assert(isStaticUnitSize && "expected unit dim");
155  droppedDims.set(idx);
156  continue;
157  }
158 
159  // Dim is preserved if the size is not a static 1.
160  if (!isStaticUnitSize) {
161  --shapePos;
162  continue;
163  }
164 
165  // Dim is preserved if the reduced shape dim is also 1.
166  if (reducedShape[shapePos] == 1) {
167  --shapePos;
168  continue;
169  }
170 
171  // Otherwise: Dim is dropped.
172  droppedDims.set(idx);
173  }
174 
175  assert(shapePos < 0 && "dimension mismatch");
176  return droppedDims;
177 }
178 
179 /// Given a ranked tensor type and a range of values that defines its dynamic
180 /// dimension sizes, turn all dynamic sizes that have a constant value into
181 /// static dimension sizes.
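/// For example (illustrative, not from the original source): given a
/// tensor<?x?xf32> with dynamic sizes (%d, %c4), where %c4 is a constant 4,
/// the folded type is tensor<?x4xf32> and only %d remains in
/// foldedDynamicSizes.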
182 static RankedTensorType
183 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
184  SmallVector<Value> &foldedDynamicSizes) {
185  SmallVector<int64_t> staticShape(type.getShape());
186  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
187  "incorrect number of dynamic sizes");
188 
189  // Compute new static and dynamic sizes.
190  unsigned ctr = 0;
191  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
192  if (type.isDynamicDim(i)) {
193  Value dynamicSize = dynamicSizes[ctr++];
194  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
195  if (cst.has_value()) {
196  // Dynamic size must be non-negative.
197  if (cst.value() < 0) {
198  foldedDynamicSizes.push_back(dynamicSize);
199  continue;
200  }
201  staticShape[i] = *cst;
202  } else {
203  foldedDynamicSizes.push_back(dynamicSize);
204  }
205  }
206  }
207 
208  return RankedTensorType::get(staticShape, type.getElementType(),
209  type.getEncoding());
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // BitcastOp
214 //===----------------------------------------------------------------------===//
215 
216 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
217  if (inputs.size() != 1 || outputs.size() != 1)
218  return false;
219  Type a = inputs.front(), b = outputs.front();
220  auto aT = dyn_cast<TensorType>(a);
221  auto bT = dyn_cast<TensorType>(b);
222  if (!aT || !bT)
223  return false;
224 
225  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
226  return false;
227 
228  return succeeded(verifyCompatibleShape(aT, bT));
229 }
230 
231 namespace {
232 
233 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
234 /// operation.
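/// For example (an illustrative sketch, not from the original source,
/// assuming 32-bit signed/unsigned/signless integers are bitcast-compatible):
///
/// ```mlir
/// %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
/// %2 = tensor.bitcast %1 : tensor<4xi32> to tensor<4xsi32>
/// ```
///
/// is rewritten to a single `tensor.bitcast %0 : tensor<4xui32> to
/// tensor<4xsi32>`.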
235 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
236  using OpRewritePattern<BitcastOp>::OpRewritePattern;
237
238  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
239  PatternRewriter &rewriter) const final {
240  auto tensorBitcastOperand =
241  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
242  if (!tensorBitcastOperand)
243  return failure();
244 
245  auto resultType = cast<TensorType>(tensorBitcast.getType());
246  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
247  tensorBitcastOperand.getOperand());
248  return success();
249  }
250 };
251 
252 } // namespace
253 
254 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
255  MLIRContext *context) {
256  results.add<ChainedTensorBitcast>(context);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // CastOp
261 //===----------------------------------------------------------------------===//
262 
263 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
264  setNameFn(getResult(), "cast");
265 }
266 
267 /// Returns true if `target` is a ranked tensor type that preserves static
268 /// information available in the `source` ranked tensor type.
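/// For example (illustrative, not from the original source): a cast from
/// tensor<?x?xf32> to tensor<?x8xf32> preserves static information (the
/// target only adds static sizes), whereas a cast from tensor<?x8xf32> to
/// tensor<?x?xf32> does not, because the static size 8 is lost.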
269 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
270  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
271  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
272 
273  // Requires RankedTensorType.
274  if (!sourceType || !targetType)
275  return false;
276 
277  // Requires same elemental type.
278  if (sourceType.getElementType() != targetType.getElementType())
279  return false;
280 
281  // Requires same rank.
282  if (sourceType.getRank() != targetType.getRank())
283  return false;
284 
285  // Requires same encoding.
286  if (sourceType.getEncoding() != targetType.getEncoding())
287  return false;
288 
289  // If cast is towards more static sizes along any dimension, don't fold.
290  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
291  if (!ShapedType::isDynamic(std::get<0>(t)) &&
292  ShapedType::isDynamic(std::get<1>(t)))
293  return false;
294  }
295 
296  return true;
297 }
298 
299 /// Determines whether tensor::CastOp casts to a more dynamic version of the
300 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
301 /// implement canonicalization patterns for ops in different dialects that may
302 /// consume the results of tensor.cast operations. Such foldable tensor.cast
303 /// operations are typically inserted as `slice` ops and are canonicalized,
304 /// operations are typically inserted as `slice` ops and are canonicalized
305 ///
306 /// Returns true when all conditions are met:
307 /// 1. source and result are ranked tensors with same element type and rank.
308 /// 2. the source type has more static information than the result type.
309 ///
310 /// Example:
311 /// ```mlir
312 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
313 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
314 /// ```
315 ///
316 /// folds into:
317 ///
318 /// ```mlir
319 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
320 /// ```
321 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
322  if (!castOp)
323  return false;
324 
325  // Can fold if the source of cast has at least as much static information as
326  // its results.
327  return preservesStaticInformation(castOp.getType(),
328  castOp.getSource().getType());
329 }
330 
331 /// Determines whether the tensor::CastOp casts to a more static version of the
332 /// source tensor. This is useful to fold into a producing op and implement
333 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
334 /// being from different dialects. Returns true when all conditions are met:
335 /// 1. source and result are ranked tensors with same element type and rank.
336 /// 2. the result type has more static information than the source.
337 ///
338 /// Example:
339 /// ```mlir
340 /// %1 = producer ... : tensor<?x?xf32>
341 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
342 /// ```
343 ///
344 /// can be canonicalized to :
345 ///
346 /// ```mlir
347 /// %2 = producer ... : tensor<8x16xf32>
348 /// ```
349 /// Not all ops might be canonicalizable this way, but for those that can be,
350 /// this method provides a check that it is worth doing the canonicalization.
351 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
352  if (!castOp)
353  return false;
354  return preservesStaticInformation(castOp.getSource().getType(),
355  castOp.getType());
356 }
357 
358 bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
359  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
360  if (llvm::isa<BlockArgument>(opOperand.get()))
361  return false;
362  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
363  return castOp && canFoldIntoConsumerOp(castOp);
364  });
365 }
366 
367 SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
368  DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
369  SmallVector<Value> newOperands;
370  newOperands.reserve(op->getNumOperands());
371 
372  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
373 
374  // Assumes that the result has dpsInits followed by nonDpsInits.
375  int64_t dpsInitIdx = 0;
376  for (OpOperand &opOperand : op->getOpOperands()) {
377  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
378  bool fold = canFoldIntoConsumerOp(tensorCastOp);
379  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
380  if (op.isDpsInit(&opOperand) &&
381  !llvm::isa<MemRefType>(newOperands.back().getType()))
382  newResTy[dpsInitIdx++] = newOperands.back().getType();
383  }
384  return newOperands;
385 }
386 
387 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
388 /// that can be folded.
389 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
390  bool folded = false;
391  for (OpOperand &operand : op->getOpOperands()) {
392  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
393  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
394  operand.set(castOp.getOperand());
395  folded = true;
396  }
397  }
398  return success(folded);
399 }
400 
401 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
402  if (inputs.size() != 1 || outputs.size() != 1)
403  return false;
404  Type a = inputs.front(), b = outputs.front();
405  auto aT = llvm::dyn_cast<TensorType>(a);
406  auto bT = llvm::dyn_cast<TensorType>(b);
407  if (!aT || !bT)
408  return false;
409 
410  if (aT.getElementType() != bT.getElementType())
411  return false;
412 
413  return succeeded(verifyCompatibleShape(aT, bT));
414 }
415 
416 /// Compute a TensorType that has the joined shape knowledge of the two
417 /// given TensorTypes. The element types need to match.
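/// For example (illustrative, not from the original source): joining
/// tensor<?x8xf32> with tensor<4x?xf32> yields tensor<4x8xf32>, while
/// joining tensor<4x8xf32> with tensor<5x8xf32> has no valid join and
/// returns a null type.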
418 static TensorType joinShapes(TensorType one, TensorType two) {
419  assert(one.getElementType() == two.getElementType());
420 
421  if (!one.hasRank())
422  return two;
423  if (!two.hasRank())
424  return one;
425 
426  int64_t rank = one.getRank();
427  if (rank != two.getRank())
428  return {};
429 
430  SmallVector<int64_t, 4> join;
431  join.reserve(rank);
432  for (int64_t i = 0; i < rank; ++i) {
433  if (one.isDynamicDim(i)) {
434  join.push_back(two.getDimSize(i));
435  continue;
436  }
437  if (two.isDynamicDim(i)) {
438  join.push_back(one.getDimSize(i));
439  continue;
440  }
441  if (one.getDimSize(i) != two.getDimSize(i))
442  return {};
443  join.push_back(one.getDimSize(i));
444  }
445  return RankedTensorType::get(join, one.getElementType());
446 }
447 
448 namespace {
449 
450 /// Replaces chains of two tensor.cast operations by a single tensor.cast
451 /// operation if doing so does not remove runtime constraints.
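/// For example (an illustrative sketch, not from the original source):
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
///
/// can be collapsed into a single `tensor.cast %0 : tensor<4x?xf32> to
/// tensor<?x8xf32>`, because the fully dynamic intermediate type adds no
/// runtime constraint beyond what source and result already imply.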
452 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
453  using OpRewritePattern<CastOp>::OpRewritePattern;
454
455  LogicalResult matchAndRewrite(CastOp tensorCast,
456  PatternRewriter &rewriter) const final {
457  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
458 
459  if (!tensorCastOperand)
460  return failure();
461 
462  auto sourceType =
463  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
464  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
465  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
466 
467  // We can remove the intermediate cast if joining all three produces the
468  // same result as just joining the source and result shapes.
469  auto firstJoin =
470  joinShapes(joinShapes(sourceType, intermediateType), resultType);
471 
472  // The join might not exist if the cast sequence would fail at runtime.
473  if (!firstJoin)
474  return failure();
475 
476  // The newJoin always exists if the above join exists; it might just contain
477  // less information. If so, we cannot drop the intermediate cast, as doing
478  // so would remove runtime checks.
479  auto newJoin = joinShapes(sourceType, resultType);
480  if (firstJoin != newJoin)
481  return failure();
482 
483  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
484  tensorCastOperand.getOperand());
485  return success();
486  }
487 };
488 
489 /// Fold tensor.cast into tensor.extract_slice producer.
490 /// Example:
491 /// ```
492 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
493 /// tensor<128x512xf32> to tensor<?x512xf32>
494 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
495 /// ```
496 /// ->
497 /// ```
498 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
499 /// tensor<128x512xf32> to tensor<16x512xf32>
500 /// ```
501 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
502  using OpRewritePattern<CastOp>::OpRewritePattern;
503
504  LogicalResult matchAndRewrite(CastOp tensorCast,
505  PatternRewriter &rewriter) const final {
506  auto extractOperand =
507  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
508 
509  // Cannot fold cast to unranked tensor.
510  auto rankedResultType =
511  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
512  if (!rankedResultType)
513  return failure();
514 
515  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
516  rankedResultType.getShape() ==
517  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
518  .getShape())
519  return failure();
520 
521  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
522  auto dimMask = computeRankReductionMask(
523  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
524  size_t dimIndex = 0;
525  for (size_t i = 0, e = sizes.size(); i < e; i++) {
526  if (dimMask && dimMask->count(i))
527  continue;
528  int64_t dim = rankedResultType.getShape()[dimIndex++];
529  if (ShapedType::isDynamic(dim))
530  continue;
531  sizes[i] = rewriter.getIndexAttr(dim);
532  }
533 
534  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
535  tensorCast, rankedResultType, extractOperand.getSource(),
536  extractOperand.getMixedOffsets(), sizes,
537  extractOperand.getMixedStrides());
538  return success();
539  }
540 };
541 
542 } // namespace
543 
544 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
545  MLIRContext *context) {
546  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // ConcatOp
551 //===----------------------------------------------------------------------===//
552 
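/// Note (illustrative example added here, not from the original source): the
/// inference below combines per-dimension sizes, e.g. concatenating
/// tensor<3x?xf32> and tensor<3x4xf32> along dim 0 infers tensor<6x4xf32>,
/// while concatenating them along dim 1 infers tensor<3x?xf32>.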
553 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
554  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
555  auto tensorTypes =
556  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
557  return llvm::cast<RankedTensorType>(type);
558  }));
559  int64_t concatRank = tensorTypes[0].getRank();
560 
561  // The concatenation dim must be in the range [0, rank).
562  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
563 
564  SmallVector<int64_t> sizes(concatRank);
565  for (int64_t i = 0, e = concatRank; i < e; ++i) {
566  if (i == dim)
567  continue;
568  SaturatedInteger size;
569  for (auto tensorType : tensorTypes)
570  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
571  sizes[i] = size.asInteger();
572  }
573  auto concatSize = SaturatedInteger::wrap(0);
574  for (auto tensorType : tensorTypes)
575  concatSize =
576  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
577  sizes[dim] = concatSize.asInteger();
578  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
579 }
580 
581 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
582  ValueRange inputs) {
583  FailureOr<RankedTensorType> resultType =
584  inferResultType(dim, inputs.getTypes());
585  assert(succeeded(resultType) && "failed to infer concatenation result type");
586  build(builder, result, *resultType, dim, inputs);
587 }
588 
589 LogicalResult ConcatOp::verify() {
590  if (getInputs().size() < 1)
591  return emitOpError("requires at least one input");
592 
593  SmallVector<RankedTensorType> inputTypes;
594  for (auto input : getInputs())
595  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
596 
597  RankedTensorType resultType = getResultType();
598  int64_t resultRank = getRank();
599  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
600  return type.getRank() != resultRank;
601  }))
602  return emitOpError("rank of concatenated inputs must match result rank");
603 
604  Type resultElementType = resultType.getElementType();
605  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
606  return type.getElementType() != resultElementType;
607  }))
608  return emitOpError("inputs and result element type must match");
609 
610  int64_t dim = getDim();
611  if (dim >= resultRank)
612  return emitOpError("concatenation dim must be less than the tensor rank");
613 
614  SmallVector<int64_t> sizes(resultRank);
615  for (int64_t i = 0, e = resultRank; i < e; ++i) {
616  if (i == dim)
617  continue;
618  SaturatedInteger size;
619  for (auto tensorType : inputTypes) {
620  FailureOr<SaturatedInteger> maybeSize =
621  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
622  if (failed(maybeSize))
623  return emitOpError("static concatenation size mismatch along ")
624  << "non-concatenated dimension " << i;
625  size = *maybeSize;
626  }
627  sizes[i] = size.asInteger();
628  }
629  auto concatSize = SaturatedInteger::wrap(0);
630  for (auto tensorType : inputTypes)
631  concatSize =
632  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
633  sizes[dim] = concatSize.asInteger();
634  auto inferredResultType =
635  RankedTensorType::get(sizes, inputTypes[0].getElementType());
636 
637  for (auto [inferredSize, actualSize] :
638  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
639  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
640  ShapedType::isDynamic(actualSize);
641  if (!hasDynamic && inferredSize != actualSize)
642  return emitOpError("result type ")
643  << resultType << " does not match inferred shape "
644  << inferredResultType << " static sizes";
645  }
646 
647  return success();
648 }
649 
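/// Note (illustrative description added here, not from the original source):
/// the decomposition below materializes a `tensor.empty` of the concatenated
/// shape and then inserts each input via `tensor.insert_slice` at a running
/// offset along the concatenation dimension, casting at the end if the
/// computed type differs from the op's result type.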
650 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
651  size_t numInputs = getInputs().size();
652  uint64_t concatDim = getDim();
653 
654  SmallVector<SmallVector<OpFoldResult>> inputShapes;
655  inputShapes.reserve(numInputs);
656  SmallVector<OpFoldResult> concatOffsets;
657  concatOffsets.reserve(numInputs);
658  SmallVector<OpFoldResult> outputShape;
659 
660  AffineExpr addExpr =
661  builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
662  OpFoldResult zero = builder.getIndexAttr(0);
663  Location loc = getLoc();
664  for (auto [index, input] : llvm::enumerate(getInputs())) {
665  SmallVector<OpFoldResult> inputShape =
666  tensor::getMixedSizes(builder, input.getLoc(), input);
667  if (index == 0) {
668  outputShape = inputShape;
669  concatOffsets.push_back(zero);
670  } else {
671  concatOffsets.push_back(outputShape[concatDim]);
672  outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
673  builder, loc, addExpr,
674  {outputShape[concatDim], inputShape[concatDim]});
675  }
676  inputShapes.emplace_back(std::move(inputShape));
677  }
678 
679  Value replacement = builder.create<tensor::EmptyOp>(
680  loc, outputShape, getType().getElementType());
681 
682  int64_t rank = getType().getRank();
683  OpFoldResult one = builder.getIndexAttr(1);
684  SmallVector<OpFoldResult> strides(rank, one);
685  SmallVector<OpFoldResult> offsets(rank, zero);
686  for (auto [index, input] : llvm::enumerate(getInputs())) {
687  offsets[concatDim] = concatOffsets[index];
688  auto insertSlice = builder.create<tensor::InsertSliceOp>(
689  loc, input, replacement, offsets, inputShapes[index], strides);
690  replacement = insertSlice.getResult();
691  }
692  if (replacement.getType() != getType()) {
693  replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
694  }
695  return SmallVector<Value>{replacement};
696 }
697 
698 LogicalResult
699 ConcatOp::reifyResultShapes(OpBuilder &builder,
700  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
701  ValueRange inputs = getInputs();
702  int64_t dim = getDim();
703  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
704 
705  Value init = inputs[0];
706  int64_t rank = getType().getRank();
707 
708  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
709 
710  // Pre-populate the result sizes with as much static information as possible
711  // from the given result type, as well as the inferred result type, otherwise
712  // use the dim sizes from the first input.
713  for (int64_t i = 0; i < rank; ++i) {
714  if (i == dim)
715  continue;
716  if (!getType().isDynamicDim(i)) {
717  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
718  } else if (!inferredResultType.isDynamicDim(i)) {
719  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
720  builder, getLoc(),
721  builder.getIndexAttr(inferredResultType.getDimSize(i)));
722  } else {
723  reifiedReturnShapes[0][i] =
724  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
725  }
726  }
727 
728  if (getType().isDynamicDim(dim)) {
729  // Take the sum of the input sizes along the concatenated dim.
730  AffineExpr sum = builder.getAffineDimExpr(0);
731  SmallVector<OpFoldResult> sizes = {
732  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
733  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
734  sum = sum + builder.getAffineDimExpr(idx + 1);
735  sizes.push_back(
736  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
737  }
738  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
739  builder, getLoc(),
740  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
741  } else {
742  // If the result shape is static along the concatenated dim, use the static
743  // shape.
744  reifiedReturnShapes[0][dim] =
745  builder.getIndexAttr(getType().getDimSize(dim));
746  }
747  return success();
748 }
749 
750 void ConcatOp::getAsmResultNames(
751  function_ref<void(Value, StringRef)> setNameFn) {
752  setNameFn(getResult(), "concat");
753 }
754 
755 OpFoldResult ConcatOp::fold(FoldAdaptor) {
756  ValueRange inputs = getInputs();
757  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
758  return inputs[0];
759  return {};
760 }
761 
762 namespace {
763 /// Fold a concat op with a single input to a cast.
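/// For example (illustrative, not from the original source):
/// `tensor.concat dim(0) %a : (tensor<?x4xf32>) -> tensor<2x4xf32>` becomes
/// `tensor.cast %a : tensor<?x4xf32> to tensor<2x4xf32>`.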
764 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
765  using OpRewritePattern<ConcatOp>::OpRewritePattern;
766
767  LogicalResult matchAndRewrite(ConcatOp concatOp,
768  PatternRewriter &rewriter) const override {
769  if (concatOp.getInputs().size() != 1)
770  return failure();
771  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
772  concatOp.getInputs()[0]);
773  return success();
774  }
775 };
776 } // namespace
777 
778 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
779  MLIRContext *context) {
780  results.add<SingleInputConcatOp>(context);
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // DimOp
785 //===----------------------------------------------------------------------===//
786 
787 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
788  setNameFn(getResult(), "dim");
789 }
790 
791 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
792  int64_t index) {
793  auto loc = result.location;
794  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
795  build(builder, result, source, indexValue);
796 }
797 
798 std::optional<int64_t> DimOp::getConstantIndex() {
799  return getConstantIntValue(getIndex());
800 }
801 
802 Speculation::Speculatability DimOp::getSpeculatability() {
803  auto constantIndex = getConstantIndex();
804  if (!constantIndex)
805  return Speculation::NotSpeculatable;
806
807  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
808  if (!rankedSourceType)
809  return Speculation::NotSpeculatable;
810
811  if (rankedSourceType.getRank() <= constantIndex)
812  return Speculation::NotSpeculatable;
813
814  return Speculation::Speculatable;
815 }
816 
817 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
818  SetIntLatticeFn setResultRange) {
819  setResultRange(getResult(),
820  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
821 }
822 
823 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
824  // All forms of folding require a known index.
825  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
826  if (!index)
827  return {};
828 
829  // Folding for unranked types (UnrankedTensorType) is not supported.
830  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
831  if (!tensorType)
832  return {};
833 
834  // Out of bound indices produce undefined behavior but are still valid IR.
835  // Don't choke on them.
836  int64_t indexVal = index.getInt();
837  if (indexVal < 0 || indexVal >= tensorType.getRank())
838  return {};
839 
840  // Fold if the shape extent along the given index is known.
841  if (!tensorType.isDynamicDim(index.getInt())) {
842  Builder builder(getContext());
843  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
844  }
845 
846  Operation *definingOp = getSource().getDefiningOp();
847 
848  // Fold dim to the operand of tensor.generate.
849  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
850  auto resultType =
851  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
852  // The case where the type encodes the size of the dimension is handled
853  // above.
854  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
855 
856  // Find the operand of the fromElements that corresponds to this index.
857  auto dynExtents = fromElements.getDynamicExtents().begin();
858  for (auto dim : resultType.getShape().take_front(index.getInt()))
859  if (ShapedType::isDynamic(dim))
860  dynExtents++;
861 
862  return Value{*dynExtents};
863  }
864 
865  // The size at the given index is now known to be a dynamic size.
866  unsigned unsignedIndex = index.getValue().getZExtValue();
867 
868  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
869  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
870  // `resolve-shaped-type-result-dims` pass.
871  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
872  sliceOp.isDynamicSize(unsignedIndex)) {
873  return {sliceOp.getDynamicSize(unsignedIndex)};
874  }
875  }
876 
877  // dim(cast) -> dim
878  if (succeeded(foldTensorCast(*this)))
879  return getResult();
880 
881  return {};
882 }
883 
884 namespace {
885 /// Fold dim of a cast into the dim of the source of the tensor cast.
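/// For example (illustrative, not from the original source): with
/// `%1 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>`, the op
/// `tensor.dim %1, %c0` is rewritten to `tensor.dim %t, %c0`.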
886 struct DimOfCastOp : public OpRewritePattern<DimOp> {
887  using OpRewritePattern<DimOp>::OpRewritePattern;
888
889  LogicalResult matchAndRewrite(DimOp dimOp,
890  PatternRewriter &rewriter) const override {
891  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
892  if (!castOp)
893  return failure();
894  Value newSource = castOp.getOperand();
895  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
896  return success();
897  }
898 };
899 
900 /// Fold dim of a destination passing style op into the dim of the corresponding
901 /// init.
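/// For example (illustrative, assuming a destination-style op such as a
/// linalg op): `tensor.dim %r, %c0`, where `%r` is the result tied to the
/// init operand `%init`, is rewritten in place to `tensor.dim %init, %c0`.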
902 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
903  using OpRewritePattern<DimOp>::OpRewritePattern;
904
905  LogicalResult matchAndRewrite(DimOp dimOp,
906  PatternRewriter &rewriter) const override {
907  auto source = dimOp.getSource();
908  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
909  if (!destOp)
910  return failure();
911 
912  auto resultIndex = cast<OpResult>(source).getResultNumber();
913  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
914 
915  rewriter.modifyOpInPlace(
916  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
917  return success();
918  }
919 };
920 
921 /// Fold dim of a tensor reshape operation to an extract into the reshape's
922 /// shape operand.
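/// For example (an illustrative sketch, not from the original source):
///
/// ```mlir
/// %r = tensor.reshape %src(%shape)
///     : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// %d = tensor.dim %r, %c1 : tensor<?x?xf32>
/// ```
///
/// becomes `%d = tensor.extract %shape[%c1] : tensor<2xindex>`.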
923 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
924  using OpRewritePattern<DimOp>::OpRewritePattern;
925
926  LogicalResult matchAndRewrite(DimOp dim,
927  PatternRewriter &rewriter) const override {
928  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
929 
930  if (!reshape)
931  return failure();
932 
933  // Since tensors are immutable, we don't need to worry about where to place
934  // the extract call.
935  rewriter.setInsertionPointAfter(dim);
936  Location loc = dim.getLoc();
937  Value extract =
938  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
939  if (extract.getType() != dim.getType())
940  extract =
941  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
942  rewriter.replaceOp(dim, extract);
943  return success();
944  }
945 };
946 } // namespace
947 
948 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
949  MLIRContext *context) {
950  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
951 }
952 
953 //===----------------------------------------------------------------------===//
954 // EmptyOp
955 //===----------------------------------------------------------------------===//
956 
957 void EmptyOp::build(OpBuilder &builder, OperationState &result,
958  ArrayRef<int64_t> staticShape, Type elementType,
959  Attribute encoding) {
960  assert(all_of(staticShape,
961  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
962  "expected only static sizes");
963  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
964 }
965 
966 void EmptyOp::build(OpBuilder &builder, OperationState &result,
967  ArrayRef<int64_t> staticShape, Type elementType,
968  ValueRange dynamicSizes, Attribute encoding) {
969  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
970  build(builder, result, tensorType, dynamicSizes);
971 }
972 
973 void EmptyOp::build(OpBuilder &builder, OperationState &result,
974  ArrayRef<OpFoldResult> sizes, Type elementType,
975  Attribute encoding) {
976  SmallVector<int64_t> staticShape;
977  SmallVector<Value> dynamicSizes;
978  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
979  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
980 }
981 
982 LogicalResult EmptyOp::verify() {
983  if (getType().getNumDynamicDims() != getDynamicSizes().size())
984  return emitOpError("incorrect number of dynamic sizes, has ")
985  << getDynamicSizes().size() << ", expected "
986  << getType().getNumDynamicDims();
987  return success();
988 }
989 
990 LogicalResult
991 EmptyOp::reifyResultShapes(OpBuilder &builder,
992  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
993  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
994  unsigned ctr = 0;
995  for (int64_t i = 0; i < getType().getRank(); ++i) {
996  if (getType().isDynamicDim(i)) {
997  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
998  } else {
999  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1000  }
1001  }
1002  return success();
1003 }
1004 
1005 Value EmptyOp::getDynamicSize(unsigned idx) {
1006  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1007  unsigned ctr = 0;
1008  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1009  if (getType().isDynamicDim(i))
1010  ++ctr;
1011  return getDynamicSizes()[ctr];
1012 }
1013 
1014 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1015  SmallVector<OpFoldResult> result;
1016  unsigned ctr = 0;
1017  OpBuilder b(getContext());
1018  for (int64_t i = 0; i < getType().getRank(); ++i) {
1019  if (getType().isDynamicDim(i)) {
1020  result.push_back(getDynamicSizes()[ctr++]);
1021  } else {
1022  result.push_back(b.getIndexAttr(getType().getShape()[i]));
1023  }
1024  }
1025  return result;
1026 }
1027 
1028 namespace {
1029 /// Change the type of the result of a `tensor.empty` by making the result
1030 /// type statically sized along dimensions that in the original operation were
1031 /// defined as dynamic, but the size was defined using a `constant` op. For
1032 /// example
1033 ///
1034 /// %c5 = arith.constant 5: index
1035 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1036 ///
1037 /// to
1038 ///
1039 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1040 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1041  using OpRewritePattern<EmptyOp>::OpRewritePattern;
1042
1043  LogicalResult matchAndRewrite(EmptyOp op,
1044  PatternRewriter &rewriter) const override {
1045  SmallVector<Value> foldedDynamicSizes;
1046  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1047  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1048 
1049  // Stop here if no dynamic size was promoted to static.
1050  if (foldedTensorType == op.getType())
1051  return failure();
1052 
1053  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
1054  foldedDynamicSizes);
1055  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1056  return success();
1057  }
1058 };
1059 
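/// Fold `tensor.dim` of a `tensor.empty` into the corresponding dynamic size
/// operand (descriptive note added here). For example (illustrative): with
/// `%e = tensor.empty(%d) : tensor<?xf32>`, the op `tensor.dim %e, %c0`
/// is replaced by `%d`.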
1060 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1061  using OpRewritePattern<DimOp>::OpRewritePattern;
1062
1063  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1064  PatternRewriter &rewriter) const override {
1065  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1066  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1067  if (!emptyTensorOp || !maybeConstantIndex)
1068  return failure();
1069  auto emptyTensorType = emptyTensorOp.getType();
1070  if (*maybeConstantIndex < 0 ||
1071  *maybeConstantIndex >= emptyTensorType.getRank() ||
1072  !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1073  return failure();
1074  rewriter.replaceOp(dimOp,
1075  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1076  return success();
1077  }
1078 };
1079 
1080 /// Canonicalize
1081 ///
1082 /// ```mlir
1083 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1084 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1085 /// ```
1086 ///
1087 /// into
1088 ///
1089 /// ```mlir
1090 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1091 /// ```
1092 ///
1093 /// This assumes the input program is correct in terms of its shape. So it is
1094 /// safe to assume that `%d0` is in fact 4.
1095 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1096  using OpRewritePattern<CastOp>::OpRewritePattern;
1097
1098  LogicalResult matchAndRewrite(CastOp castOp,
1099  PatternRewriter &rewriter) const override {
1100  if (!canFoldIntoProducerOp(castOp))
1101  return failure();
1102  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1103  if (!producer)
1104  return failure();
1105 
1106  auto resultType =
1107  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1108  ArrayRef<int64_t> resultShape = resultType.getShape();
1109  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1110  SmallVector<OpFoldResult> newMixedSizes;
1111  newMixedSizes.reserve(currMixedSizes.size());
1112  assert(resultShape.size() == currMixedSizes.size() &&
1113  "mismatch in result shape and sizes of empty op");
1114  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1115  int64_t newDim = std::get<0>(it);
1116  OpFoldResult currDim = std::get<1>(it);
1117  // Case 1: The empty tensor dim is static. Check that the tensor cast
1118  // result dim matches.
1119  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1120  if (ShapedType::isDynamic(newDim) ||
1121  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1122  // Something is off, the cast result shape cannot be more dynamic
1123  // than the empty tensor result shape (enforced by
1124  // `canFoldIntoProducerOp`). Abort for now.
1125  return rewriter.notifyMatchFailure(
1126  producer, "mismatch in static value of shape of empty tensor "
1127  "result and cast result");
1128  }
1129  newMixedSizes.push_back(attr);
1130  continue;
1131  }
1132 
1133  // Case 2 : The tensor cast shape is static, but empty tensor result
1134  // shape is dynamic.
1135  if (!ShapedType::isDynamic(newDim)) {
1136  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1137  continue;
1138  }
1139 
1140  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1141  // shape is dynamic. Use the dynamic value from the empty tensor op.
1142  newMixedSizes.push_back(currDim);
1143  }
1144 
1145  // TODO: Do not drop tensor encoding.
1146  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1147  resultType.getElementType());
1148  return success();
1149  }
1150 };
1151 
1152 } // namespace
1153 
1154 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1155  MLIRContext *context) {
1156  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1157  ReplaceEmptyTensorStaticShapeDims>(context);
1158 }
1159 
1160 //===----------------------------------------------------------------------===//
1161 // ExtractOp
1162 //===----------------------------------------------------------------------===//
1163 
1164 namespace {
1165 
1166 /// Canonicalizes the pattern of the form
1167 ///
1168 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1169 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1170 ///
1171 /// to
1172 ///
1173 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1174 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1176  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1177  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1178  PatternRewriter &rewriter) const final {
1179  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1180  if (!tensorCast)
1181  return failure();
1182  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1183  return failure();
1184  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1185  extract, tensorCast.getSource(), extract.getIndices());
1186  return success();
1187  }
1188 };
1189 
1190 } // namespace
1191 
1192 void ExtractOp::getAsmResultNames(
1193  function_ref<void(Value, StringRef)> setNameFn) {
1194  setNameFn(getResult(), "extracted");
1195 }
1196 
1197 LogicalResult ExtractOp::verify() {
1198  // Verify the # indices match if we have a ranked type.
1199  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1200  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1201  return emitOpError("incorrect number of indices for extract_element");
1202  return success();
1203 }
1204 
1205 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1206  if (Attribute tensor = adaptor.getTensor()) {
1207  // If this is a splat elements attribute, simply return the value.
1208  // All of the elements of a splat attribute are the same.
1209  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1210  return splatTensor.getSplatValue<Attribute>();
1211 
1212  // If this is a dense resource elements attribute, return.
1213  if (isa<DenseResourceElementsAttr>(tensor))
1214  return {};
1215  }
1216 
1217  // Collect the constant indices into the tensor.
1218  SmallVector<uint64_t, 8> indices;
1219  for (Attribute indice : adaptor.getIndices()) {
1220  if (!indice || !llvm::isa<IntegerAttr>(indice))
1221  return {};
1222  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1223  }
1224 
1225  // Fold extract(from_elements(...)).
1226  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1227  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1228  auto rank = tensorType.getRank();
1229  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1230  "rank mismatch");
1231  int flatIndex = 0;
1232  int stride = 1;
1233  for (int i = rank - 1; i >= 0; --i) {
1234  flatIndex += indices[i] * stride;
1235  stride *= tensorType.getDimSize(i);
1236  }
1237  // Prevent out of bounds accesses. This can happen in invalid code that
1238  // will never execute.
1239  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1240  flatIndex < 0)
1241  return {};
1242  return fromElementsOp.getElements()[flatIndex];
1243  }
1244 
1245  // If this is an elements attribute, query the value at the given indices.
1246  if (Attribute tensor = adaptor.getTensor()) {
1247  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1248  if (elementsAttr && elementsAttr.isValidIndex(indices))
1249  return elementsAttr.getValues<Attribute>()[indices];
1250  }
1251 
1252  return {};
1253 }
1254 
1255 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1256  MLIRContext *context) {
1257  results.add<ExtractFromTensorCast>(context);
1258 }
1259 
1260 //===----------------------------------------------------------------------===//
1261 // FromElementsOp
1262 //===----------------------------------------------------------------------===//
1263 
1264 void FromElementsOp::getAsmResultNames(
1265  function_ref<void(Value, StringRef)> setNameFn) {
1266  setNameFn(getResult(), "from_elements");
1267 }
1268 
1269 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1270  ValueRange elements) {
1271  assert(!elements.empty() && "expected at least one element");
1272  Type resultType = RankedTensorType::get(
1273  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1274  build(builder, result, resultType, elements);
1275 }
1276 
1277 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1278  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1279  return DenseElementsAttr::get(getType(), adaptor.getElements());
1280  return {};
1281 }
1282 
1283 namespace {
1284 
1285 // Pushes the index_casts that occur before extractions to after the extract.
1286 // This minimizes type conversion in some cases and enables the extract
1287 // canonicalizer. This changes:
1288 //
1289 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1290 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1291 //
1292 // to the following:
1293 //
1294 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1295 // %cast = arith.index_cast %extract : i32 to index
1296 //
1298 //
1299 // Consider expanding this to a template and handle all tensor cast
1300 // operations.
1301 struct ExtractElementFromIndexCast
1302  : public OpRewritePattern<tensor::ExtractOp> {
1303  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1304
1305  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1306  PatternRewriter &rewriter) const final {
1307  Location loc = extract.getLoc();
1308  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1309  if (!indexCast)
1310  return failure();
1311 
1312  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1313 
1314  auto newExtract = rewriter.create<tensor::ExtractOp>(
1315  loc, elementTy, indexCast.getIn(), extract.getIndices());
1316 
1317  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1318  newExtract);
1319 
1320  return success();
1321  }
1322 };
1323 
1324 } // namespace
1325 
1326 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1327  MLIRContext *context) {
1328  results.add<ExtractElementFromIndexCast>(context);
1329 }
1330 
1331 //===----------------------------------------------------------------------===//
1332 // GatherOp
1333 //===----------------------------------------------------------------------===//
1334 
1335 void GatherOp::getAsmResultNames(
1336  function_ref<void(Value, StringRef)> setNameFn) {
1337  setNameFn(getResult(), "gather");
1338 }
1339 
1340 /// Return the inferred result type for a gatherOp where:
1341 /// - sourceType is the type of the source tensor gathered from
1342 /// - indicesType is the type of the indices used to gather
1343 /// - gatherDims are the dims along which the gather occurs.
1344 /// Return a full-rank or rank-reduced variant of the type depending on
1345 /// the value of rankReduced.
1346 ///
1347 /// The leading dimensions of the index tensor give the result tensor its
1348 /// leading dimensions.
1349 /// The trailing dimensions of the result tensor are obtained from the source
1350 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1351 /// rankReduced is false), or skipping them (otherwise).
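/// For example (illustrative, not from the original source): gathering from
/// tensor<4x5x6xf32> with indices tensor<1x2x1xindex> and gather_dims([1])
/// infers tensor<1x2x4x1x6xf32> when rankReduced is false and
/// tensor<1x2x4x6xf32> when it is true.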
1352 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1353  RankedTensorType indicesType,
1354  ArrayRef<int64_t> gatherDims,
1355  bool rankReduced) {
1356  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1357  resultShape.reserve(resultShape.size() + sourceType.getRank());
1358  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1359  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1360  if (!rankReduced)
1361  resultShape.push_back(1);
1362  continue;
1363  }
1364  resultShape.push_back(sourceType.getDimSize(idx));
1365  }
1366  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1367 }
1368 
1369 static LogicalResult
1370 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1371  ArrayRef<int64_t> indices, int64_t rank,
1372  StringRef gatherOrScatter, StringRef sourceOrDest) {
1373  if (dims.empty())
1374  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1375 
1376  int64_t numGatherDims = dims.size();
1377  if (numGatherDims > rank)
1378  return op->emitOpError(gatherOrScatter)
1379  << "_dims overflow " << sourceOrDest << " rank";
1380  if (indices.empty() || indices.back() != numGatherDims)
1381  return op->emitOpError(gatherOrScatter)
1382  << "_dims length must match the size of last dimension of indices";
1383  for (int64_t val : dims) {
1384  if (val < 0)
1385  return op->emitOpError(gatherOrScatter)
1386  << "_dims value must be non-negative";
1387  if (val >= rank)
1388  return op->emitOpError(gatherOrScatter)
1389  << "_dims value must be smaller than " << sourceOrDest << " rank";
1390  }
1391  for (int64_t i = 1; i < numGatherDims; ++i) {
1392  if (dims[i - 1] >= dims[i])
1393  return op->emitOpError(gatherOrScatter)
1394  << "_dims values must be strictly increasing";
1395  }
1396  return success();
1397 }
1398 
1399 LogicalResult GatherOp::verify() {
1400  int64_t sourceRank = getSourceType().getRank();
1401  ArrayRef<int64_t> gatherDims = getGatherDims();
1402  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1403  getIndicesType().getShape(), sourceRank,
1404  "gather", "source")))
1405  return failure();
1406 
1407  RankedTensorType expectedResultType = GatherOp::inferResultType(
1408  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1409  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1410  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1411  if (getResultType() != expectedResultType &&
1412  getResultType() != expectedRankReducedResultType) {
1413  return emitOpError("result type "
1414  "mismatch: "
1415  "expected ")
1416  << expectedResultType << " or its rank-reduced variant "
1417  << expectedRankReducedResultType << " (got: " << getResultType()
1418  << ")";
1419  }
1420 
1421  return success();
1422 }
1423 
1424 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1425  if (OpFoldResult reshapedSource = reshapeConstantSource(
1426  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1427  getResult().getType()))
1428  return reshapedSource;
1429  return {};
1430 }
1431 
1432 //===----------------------------------------------------------------------===//
1433 // InsertOp
1434 //===----------------------------------------------------------------------===//
1435 
1436 void InsertOp::getAsmResultNames(
1437  function_ref<void(Value, StringRef)> setNameFn) {
1438  setNameFn(getResult(), "inserted");
1439 }
1440 
1441 LogicalResult InsertOp::verify() {
1442  // Verify the # indices match if we have a ranked type.
1443  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1444  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1445  return emitOpError("incorrect number of indices");
1446  return success();
1447 }
1448 
1449 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1450  Attribute scalar = adaptor.getScalar();
1451  Attribute dest = adaptor.getDest();
1452  if (scalar && dest)
1453  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1454  if (scalar == splatDest.getSplatValue<Attribute>())
1455  return dest;
1456  return {};
1457 }
1458 
1459 //===----------------------------------------------------------------------===//
1460 // GenerateOp
1461 //===----------------------------------------------------------------------===//
1462 
1463 void GenerateOp::getAsmResultNames(
1464  function_ref<void(Value, StringRef)> setNameFn) {
1465  setNameFn(getResult(), "generated");
1466 }
1467 
1468 LogicalResult GenerateOp::reifyResultShapes(
1469  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1470  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1471  int idx = 0;
1472  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1473  if (getType().isDynamicDim(dim)) {
1474  reifiedReturnShapes[0][dim] = getOperand(idx++);
1475  } else {
1476  reifiedReturnShapes[0][dim] =
1477  builder.getIndexAttr(getType().getDimSize(dim));
1478  }
1479  }
1480  return success();
1481 }
1482 
1483 LogicalResult GenerateOp::verify() {
1484  // Ensure that the tensor type has as many dynamic dimensions as are
1485  // specified by the operands.
1486  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1487  if (getNumOperands() != resultType.getNumDynamicDims())
1488  return emitError("must have as many index operands as dynamic extents "
1489  "in the result type");
1490  return success();
1491 }
1492 
1493 LogicalResult GenerateOp::verifyRegions() {
1494  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1495  // Ensure that region arguments span the index space.
1496  if (!llvm::all_of(getBody().getArgumentTypes(),
1497  [](Type ty) { return ty.isIndex(); }))
1498  return emitError("all body arguments must be index");
1499  if (getBody().getNumArguments() != resultTy.getRank())
1500  return emitError("must have one body argument per input dimension");
1501 
1502  // Ensure that the region yields an element of the right type.
1503  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1504 
1505  if (yieldOp.getValue().getType() != resultTy.getElementType())
1506  return emitOpError(
1507  "body must be terminated with a `yield` operation of the tensor "
1508  "element type");
1509 
1510  return success();
1511 }
1512 
1513 void GenerateOp::build(
1514  OpBuilder &b, OperationState &result, Type resultTy,
1515  ValueRange dynamicExtents,
1516  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1517  build(b, result, resultTy, dynamicExtents);
1518 
1519  // Build and populate body.
1520  OpBuilder::InsertionGuard guard(b);
1521  Region *bodyRegion = result.regions.front().get();
1522  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1523  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1524  SmallVector<Location, 2> argumentLocs(rank, result.location);
1525  Block *bodyBlock =
1526  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1527  bodyBuilder(b, result.location, bodyBlock->getArguments());
1528 }
1529 
1530 namespace {
1531 
1532 /// Canonicalizes tensor.generate operations with a constant
1533 /// operand into the equivalent operation with the operand expressed in the
1534 /// result type, instead. We also insert a type cast to make sure that the
1535 /// resulting IR is still well-typed.
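/// For example (an illustrative sketch, not from the original source):
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = tensor.generate %c8 { ... } : tensor<?xindex>
/// ```
///
/// becomes a `tensor.generate` with no dynamic extents producing
/// tensor<8xindex>, followed by a `tensor.cast` back to tensor<?xindex>.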
1536 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1537  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1538
1539  LogicalResult matchAndRewrite(GenerateOp generateOp,
1540  PatternRewriter &rewriter) const final {
1541  SmallVector<Value> foldedDynamicSizes;
1542  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1543  generateOp.getType(), generateOp.getDynamicExtents(),
1544  foldedDynamicSizes);
1545 
1546  // Stop here if no dynamic size was promoted to static.
1547  if (foldedTensorType == generateOp.getType())
1548  return failure();
1549 
1550  auto loc = generateOp.getLoc();
1551  auto newOp =
1552  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1553  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1554  newOp.getBody().begin());
1555  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1556  generateOp.getType(), newOp);
1557  return success();
1558  }
1559 };
1560 
1561 /// Canonicalizes the pattern of the form
1562 ///
1563 /// %tensor = tensor.generate %x {
1564 /// ^bb0(%arg0: index):
1565 /// <computation>
1566 /// yield %1 : index
1567 /// } : tensor<?xindex>
1568 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1569 ///
1570 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1571 /// tensor.generate operation has no side-effects.
1572 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1573  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1574
1575  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1576  PatternRewriter &rewriter) const final {
1577  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1578  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1579  return failure();
1580 
1581  IRMapping mapping;
1582  Block *body = &tensorFromElements.getBody().front();
1583  mapping.map(body->getArguments(), extract.getIndices());
1584  for (auto &op : body->without_terminator())
1585  rewriter.clone(op, mapping);
1586 
1587  auto yield = cast<YieldOp>(body->getTerminator());
1588 
1589  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1590  return success();
1591  }
1592 };
1593 
1594 } // namespace
1595 
1596 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1597  MLIRContext *context) {
1598  // TODO: Move extract pattern to tensor::ExtractOp.
1599  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1600 }
1601 
1602 //===----------------------------------------------------------------------===//
1603 // RankOp
1604 //===----------------------------------------------------------------------===//
1605 
1606 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1607  setNameFn(getResult(), "rank");
1608 }
1609 
1610 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1611  // Constant fold rank when the rank of the operand is known.
1612  auto type = getOperand().getType();
1613  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1614  if (shapedType && shapedType.hasRank())
1615  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1616  return IntegerAttr();
1617 }
1618 
1619 //===----------------------------------------------------------------------===//
1620 // ReshapeOp
1621 //===----------------------------------------------------------------------===//
1622 
1623 void ReshapeOp::getAsmResultNames(
1624  function_ref<void(Value, StringRef)> setNameFn) {
1625  setNameFn(getResult(), "reshape");
1626 }
1627 
1628 static int64_t getNumElements(ShapedType type) {
1629  int64_t numElements = 1;
1630  for (auto dim : type.getShape())
1631  numElements *= dim;
1632  return numElements;
1633 }
1634 
1635 LogicalResult ReshapeOp::verify() {
1636  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1637  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1638 
1639  if (operandType.getElementType() != resultType.getElementType())
1640  return emitOpError("element types of source and destination tensor "
1641  "types should be the same");
1642 
1643  int64_t shapeSize =
1644  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1645  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1646  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1647 
1648  if (resultRankedType) {
1649  if (operandRankedType && resultRankedType.hasStaticShape() &&
1650  operandRankedType.hasStaticShape()) {
1651  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1652  return emitOpError("source and destination tensor should have the "
1653  "same number of elements");
1654  }
1655  if (ShapedType::isDynamic(shapeSize))
1656  return emitOpError("cannot use shape operand with dynamic length to "
1657  "reshape to statically-ranked tensor type");
1658  if (shapeSize != resultRankedType.getRank())
1659  return emitOpError(
1660  "length of shape operand differs from the result's tensor rank");
1661  }
1662  return success();
1663 }
1664 
1665 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1666  if (OpFoldResult reshapedSource = reshapeConstantSource(
1667  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1668  getResult().getType()))
1669  return reshapedSource;
1670 
1671  // If the producer of the 'source' operand is another 'tensor.reshape' op,
1672  // reshape the producer's input directly. This may leave the producer
1673  // without uses, making it dead code.
1674  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1675  getSourceMutable().assign(reshapeOpProducer.getSource());
1676  return getResult();
1677  }
1678 
1679  auto source = getSource();
1680  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1681  auto resultTy = dyn_cast<RankedTensorType>(getType());
1682  if (!sourceTy || !resultTy || sourceTy != resultTy)
1683  return {};
1684 
1685  // If the source and result are both 1D tensors and have the same type, the
1686  // reshape has no effect, even if the tensor is dynamically shaped.
1687  if (sourceTy.getRank() == 1)
1688  return source;
1689 
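  // Illustrative sketch of the case handled below (values invented, not from
  // a test): a reshape whose target shape is rebuilt from the source itself
  // is a no-op even with dynamic dimensions, e.g.
  //   %d0 = tensor.dim %src, %c0 : tensor<?x4xf32>
  //   %shape = tensor.from_elements %d0, %c4 : tensor<2xindex>
  //   %r = tensor.reshape %src(%shape)
  //       : (tensor<?x4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
  // folds to %src.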
1690  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1691  auto elements = fromElements.getElements();
1692  bool dynamicNoop =
1693  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1694  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1695  auto element = elements[id];
1696 
1697  if (auto cst = getConstantIntValue(element)) {
1698  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1699  continue;
1700  }
1701 
1702  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1703  dynamicNoop &= dimOp.getSource() == source;
1704 
1705  APSInt dim;
1706  auto cst = getConstantIntValue(dimOp.getIndex());
1707  dynamicNoop &=
1708  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1709  continue;
1710  }
1711 
1712  dynamicNoop = false;
1713  break;
1714  }
1715 
1716  if (dynamicNoop)
1717  return source;
1718  }
1719 
1720  return {};
1721 }
1722 
1723 //===----------------------------------------------------------------------===//
1724 // Reassociative reshape ops
1725 //===----------------------------------------------------------------------===//
1726 
1727 void CollapseShapeOp::getAsmResultNames(
1728  function_ref<void(Value, StringRef)> setNameFn) {
1729  setNameFn(getResult(), "collapsed");
1730 }
1731 
1732 void ExpandShapeOp::getAsmResultNames(
1733  function_ref<void(Value, StringRef)> setNameFn) {
1734  setNameFn(getResult(), "expanded");
1735 }
1736 
1737 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1738  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1739  "invalid resultDim");
1740  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1741  if (llvm::is_contained(it.value(), resultDim))
1742  return it.index();
1743  llvm_unreachable("could not find reassociation group");
1744 }
1745 
1746 FailureOr<SmallVector<OpFoldResult>>
1747 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1748  RankedTensorType expandedType,
1749  ArrayRef<ReassociationIndices> reassociation,
1750  ArrayRef<OpFoldResult> inputShape) {
1751  std::optional<SmallVector<OpFoldResult>> outputShape =
1752  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1753  inputShape);
1754  if (!outputShape)
1755  return failure();
1756  return *outputShape;
1757 }
1758 
1759 SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1760  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1761 }
1762 
1763 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1764  Type resultType, Value src,
1765  ArrayRef<ReassociationIndices> reassociation,
1766  ArrayRef<OpFoldResult> outputShape) {
1767  auto [staticOutputShape, dynamicOutputShape] =
1768  decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1769  build(builder, result, cast<RankedTensorType>(resultType), src,
1770  getReassociationIndicesAttribute(builder, reassociation),
1771  dynamicOutputShape, staticOutputShape);
1772 }
1773 
1774 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1775  Type resultType, Value src,
1776  ArrayRef<ReassociationIndices> reassociation) {
1777  SmallVector<OpFoldResult> inputShape =
1778  getMixedSizes(builder, result.location, src);
1779  auto tensorResultTy = cast<RankedTensorType>(resultType);
1780  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1781  builder, result.location, tensorResultTy, reassociation, inputShape);
1782  SmallVector<OpFoldResult> outputShapeOrEmpty;
1783  if (succeeded(outputShape)) {
1784  outputShapeOrEmpty = *outputShape;
1785  }
1786  build(builder, result, tensorResultTy, src, reassociation,
1787  outputShapeOrEmpty);
1788 }
1789 
1790 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1791  return getSymbolLessAffineMaps(getReassociationExprs());
1792 }
1793 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1794  return convertReassociationIndicesToExprs(getContext(),
1795  getReassociationIndices());
1796 }
1797 
1798 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1799  return getSymbolLessAffineMaps(getReassociationExprs());
1800 }
1801 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1802  return convertReassociationIndicesToExprs(getContext(),
1803  getReassociationIndices());
1804 }
1805 
1806 RankedTensorType CollapseShapeOp::inferCollapsedType(
1807  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1808  return inferCollapsedType(
1809  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1810  type.getContext(), reassociation)));
1811 }
1812 
1813 /// Compute the RankedTensorType obtained by applying `reassociation` to
1814 /// `type`.
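///
/// A hedged example (values invented): collapsing tensor<2x3x?x5xf32> with
/// reassociation groups [0, 1] and [2, 3] yields tensor<6x?xf32>; any group
/// that contains a dynamic dimension collapses to '?'.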
1815 RankedTensorType
1816 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1817  ArrayRef<AffineMap> reassociation) {
1818  auto shape = type.getShape();
1819  SmallVector<int64_t, 4> newShape;
1820  newShape.reserve(reassociation.size());
1821 
1822  // Use the fact that reassociation is valid to simplify the logic: only use
1823  // each map's rank.
1824  assert(isReassociationValid(reassociation) && "invalid reassociation");
1825  unsigned currentDim = 0;
1826  for (AffineMap m : reassociation) {
1827  unsigned dim = m.getNumResults();
1828  auto band = shape.slice(currentDim, dim);
1829  int64_t size = 1;
1830  if (llvm::is_contained(band, ShapedType::kDynamic))
1831  size = ShapedType::kDynamic;
1832  else
1833  for (unsigned d = 0; d < dim; ++d)
1834  size *= shape[currentDim + d];
1835  newShape.push_back(size);
1836  currentDim += dim;
1837  }
1838 
1839  return RankedTensorType::get(newShape, type.getElementType());
1840 }
1841 
1842 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1843  ArrayRef<ReassociationIndices> reassociation,
1844  ArrayRef<NamedAttribute> attrs) {
1845  auto resultType = inferCollapsedType(
1846  llvm::cast<RankedTensorType>(src.getType()),
1847  getSymbolLessAffineMaps(
1848  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1849  result.addAttribute(getReassociationAttrStrName(),
1850  getReassociationIndicesAttribute(b, reassociation));
1851  build(b, result, resultType, src, attrs);
1852 }
1853 
1854 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1855  TensorReshapeOp, ExpandShapeOp>::value>
1856 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1857  RankedTensorType expandedType,
1858  RankedTensorType collapsedType) {
1859  if (failed(
1860  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1861  return failure();
1862 
1863  auto maps = op.getReassociationMaps();
1864  RankedTensorType expectedType =
1865  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1866  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1867  return op.emitOpError("expected collapsed type to be ")
1868  << expectedType << ", but got " << collapsedType;
1869  return success();
1870 }
1871 
1872 LogicalResult ExpandShapeOp::verify() {
1873  auto srcType = getSrcType();
1874  auto resultType = getResultType();
1875 
1876  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1877  return emitOpError("expected number of static shape dims to be equal to "
1878  "the output rank (")
1879  << resultType.getRank() << ") but found "
1880  << getStaticOutputShape().size() << " inputs instead";
1881 
1882  if ((int64_t)getOutputShape().size() !=
1883  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1884  return emitOpError("mismatch in dynamic dims in output_shape and "
1885  "static_output_shape: static_output_shape has ")
1886  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1887  << " dynamic dims while output_shape has " << getOutputShape().size()
1888  << " values";
1889 
1890  return verifyTensorReshapeOp(*this, resultType, srcType);
1891 }
1892 
1893 LogicalResult CollapseShapeOp::verify() {
1894  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1895 }
1896 
1897 namespace {
1898 /// Reshape of a splat constant can be replaced with a constant of the result
1899 /// type.
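/// For illustration only (values invented for this sketch):
///
///   %cst = arith.constant dense<1.0> : tensor<4x2xf32>
///   %r = tensor.collapse_shape %cst [[0, 1]]
///       : tensor<4x2xf32> into tensor<8xf32>
///
/// becomes
///
///   %r = arith.constant dense<1.0> : tensor<8xf32>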
1900 template <typename TensorReshapeOp>
1901 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1902  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1903  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1904  PatternRewriter &rewriter) const override {
1905  DenseElementsAttr attr;
1906  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1907  return failure();
1908  if (!attr || !attr.isSplat())
1909  return failure();
1910  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1911  reshapeOp.getResultType(), attr.getRawData());
1912  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1913  return success();
1914  }
1915 };
1916 
1917 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1918 template <typename TensorReshapeOp>
1919 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1920 public:
1921  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1922 
1923  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1924  PatternRewriter &rewriter) const override {
1925  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1926  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1927  return failure();
1928 
1929  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1930  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1931  return success();
1932  }
1933 };
1934 
1935 /// Reshape of a FromElements can be replaced with a FromElements of the
1936 /// result type.
1937 template <typename TensorReshapeOp>
1938 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1939  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1940  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1941  PatternRewriter &rewriter) const override {
1942  auto fromElements =
1943  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1944  if (!fromElements)
1945  return failure();
1946 
1947  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1948 
1949  if (!shapedTy.hasStaticShape())
1950  return failure();
1951 
1952  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1953  fromElements.getElements());
1954  return success();
1955  }
1956 };
1957 
1958 // Fold CastOp into CollapseShapeOp when adding static information.
1959 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1960  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1961 
1962  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1963  PatternRewriter &rewriter) const override {
1964  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1965  if (!tensor::canFoldIntoConsumerOp(castOp))
1966  return failure();
1967 
1968  RankedTensorType srcType =
1969  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1970  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1971  srcType, collapseShapeOp.getReassociationMaps());
1972 
1973  if (newResultType == collapseShapeOp.getResultType()) {
1974  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1975  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1976  });
1977  } else {
1978  auto newOp = rewriter.create<CollapseShapeOp>(
1979  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1980  collapseShapeOp.getReassociation());
1981  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1982  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1983  }
1984  return success();
1985  }
1986 };
1987 
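/// Replace tensor.dim of a dynamic result dimension of tensor.expand_shape
/// with an affine.apply on the corresponding source dimension. A sketch with
/// invented values:
///
///   %e = tensor.expand_shape %t [[0, 1]] output_shape [%n, 4]
///       : tensor<?xf32> into tensor<?x4xf32>
///   %d = tensor.dim %e, %c0 : tensor<?x4xf32>
///
/// is rewritten to the source dim size floordiv 4:
///
///   %s = tensor.dim %t, %c0 : tensor<?xf32>
///   %d = affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%s]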
1988 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1989  using OpRewritePattern<DimOp>::OpRewritePattern;
1990 
1991  LogicalResult matchAndRewrite(DimOp dimOp,
1992  PatternRewriter &rewriter) const override {
1993  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1994  if (!expandShapeOp)
1995  return failure();
1996 
1997  // Only constant dimension values are supported.
1998  std::optional<int64_t> dim = dimOp.getConstantIndex();
1999  if (!dim.has_value())
2000  return failure();
2001 
2002  // Skip static dims. These are folded to constant ops.
2003  RankedTensorType resultType = expandShapeOp.getResultType();
2004  if (!resultType.isDynamicDim(*dim))
2005  return failure();
2006 
2007  // Find reassociation group that contains this result dimension.
2008  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
2009 
2010  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
2011  // ExpandShapeOp would be ambiguous.)
2012  int64_t product = 1;
2013  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
2014  for (int64_t d : grp) {
2015  if (d != dim) {
2016  assert(!resultType.isDynamicDim(d) && "expected static dim");
2017  product *= resultType.getDimSize(d);
2018  }
2019  }
2020 
2021  // result dim size = src dim size / (product(other dims in reassoc group))
2022  Value srcDimSz =
2023  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
2024  AffineExpr expr;
2025  bindSymbols(dimOp.getContext(), expr);
2026  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
2027  dimOp, expr.floorDiv(product), srcDimSz);
2028  return success();
2029  }
2030 };
2031 
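/// Replace tensor.dim of a dynamic result dimension of tensor.collapse_shape
/// with the product of the corresponding source dimensions. A sketch with
/// invented values:
///
///   %c = tensor.collapse_shape %t [[0, 1]]
///       : tensor<?x4xf32> into tensor<?xf32>
///   %d = tensor.dim %c, %c0 : tensor<?xf32>
///
/// is rewritten to an affine.apply computing dim(%t, 0) * dim(%t, 1).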
2032 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
2033  using OpRewritePattern<DimOp>::OpRewritePattern;
2034 
2035  LogicalResult matchAndRewrite(DimOp dimOp,
2036  PatternRewriter &rewriter) const override {
2037  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
2038  if (!collapseShapeOp)
2039  return failure();
2040 
2041  // Only constant dimension values are supported.
2042  std::optional<int64_t> dim = dimOp.getConstantIndex();
2043  if (!dim.has_value() ||
2044  dim.value() >= collapseShapeOp.getResultType().getRank())
2045  return failure();
2046 
2047  // Skip static dims. These are folded to constant ops.
2048  RankedTensorType resultType = collapseShapeOp.getResultType();
2049  if (!resultType.isDynamicDim(*dim))
2050  return failure();
2051 
2052  // Get reassociation group of the result dimension.
2053  ReassociationIndices group =
2054  collapseShapeOp.getReassociationIndices()[*dim];
2055 
2056  // result dim size = product(dims in reassoc group)
2057  SmallVector<Value> srcDimSizes;
2058  SmallVector<AffineExpr> syms;
2059  AffineExpr product;
2060  for (const auto &it : llvm::enumerate(group)) {
2061  srcDimSizes.push_back(rewriter.create<DimOp>(
2062  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
2063  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
2064  product = product ? product * syms.back() : syms.back();
2065  }
2066  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
2067  srcDimSizes);
2068  return success();
2069  }
2070 };
2071 
2072 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2073 /// matching constant output_shape operands of the expand. This makes the
2074 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2075 /// propagated further.
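///
/// A rough sketch (names and values invented for illustration); modulo
/// intermediate casts that fold away,
///
///   %c = tensor.cast %arg : tensor<2xf32> to tensor<?xf32>
///   %e = tensor.expand_shape %c [[0, 1]] output_shape [%c1, %c2]
///       : tensor<?xf32> into tensor<?x?xf32>
///
/// with %c1 = 1 and %c2 = 2 constant becomes
///
///   %e = tensor.expand_shape %arg [[0, 1]] output_shape [1, 2]
///       : tensor<2xf32> into tensor<1x2xf32>
///   %r = tensor.cast %e : tensor<1x2xf32> to tensor<?x?xf32>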
2076 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2077  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2078 
2079  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2080  PatternRewriter &rewriter) const override {
2081  auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2082  if (!canFoldIntoConsumerOp(castOp))
2083  return failure();
2084 
2085  ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2086  SmallVector<ReassociationIndices, 4> reassoc =
2087  expandOp.getReassociationIndices();
2088 
2089  SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2090  SmallVector<Value> dynamicOutputShape;
2091  auto outputIt = expandOp.getOutputShape().begin();
2092 
2093  for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2094  for (uint64_t outDim : innerReassoc) {
2095  if (!ShapedType::isDynamic(newOutputShape[outDim]))
2096  continue;
2097 
2098  // If the cast's src type is dynamic, don't infer any of the
2099  // corresponding expanded dimensions. `tensor.expand_shape` requires at
2100  // least one of the expanded dimensions to be dynamic if the input is
2101  // dynamic.
2102  Value val = *outputIt;
2103  ++outputIt;
2104  if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2105  dynamicOutputShape.push_back(val);
2106  continue;
2107  }
2108 
2109  APInt cst;
2110  if (matchPattern(val, m_ConstantInt(&cst))) {
2111  newOutputShape[outDim] = cst.getSExtValue();
2112  } else {
2113  dynamicOutputShape.push_back(val);
2114  }
2115  }
2116  }
2117 
2118  // Couldn't match any values, nothing to change
2119  if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2120  return failure();
2121 
2122  // Calculate the input shape from the output
2123  SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2124  for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2125  for (auto outDim : reassoc[inDim]) {
2126  auto ofr = newOutputShape[outDim];
2127  if (ShapedType::isDynamic(ofr)) {
2128  newInputShape[inDim] = ShapedType::kDynamic;
2129  break;
2130  }
2131  newInputShape[inDim] *= ofr;
2132  }
2133  }
2134 
2135  SmallVector<OpFoldResult> outputOfr =
2136  getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2137  auto inputType = RankedTensorType::get(
2138  newInputShape, expandOp.getSrcType().getElementType());
2139  auto outputType = RankedTensorType::get(
2140  newOutputShape, expandOp.getSrcType().getElementType());
2141  auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
2142  expandOp.getSrc());
2143  auto newExpand = rewriter.create<ExpandShapeOp>(
2144  expandOp.getLoc(), outputType, inputCast.getResult(),
2145  expandOp.getReassociationIndices(), outputOfr);
2146  rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2147  newExpand.getResult());
2148  return success();
2149  }
2150 };
2151 } // namespace
2152 
2153 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2154  MLIRContext *context) {
2155  results.add<
2156  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2157  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2158  ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2159  FoldReshapeWithSplat<ExpandShapeOp>,
2160  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
2161  FoldDimOfCollapseShape>(context);
2162 }
2163 
2164 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2165  MLIRContext *context) {
2166  results.add<
2167  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2168  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2169  tensor::DimOp, RankedTensorType>,
2170  FoldReshapeWithConstant<CollapseShapeOp>,
2171  FoldReshapeWithSplat<CollapseShapeOp>,
2172  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2173  context);
2174 }
2175 
2176 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2177  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2178  adaptor.getOperands());
2179 }
2180 
2181 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2182  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2183  adaptor.getOperands());
2184 }
2185 
2186 //===----------------------------------------------------------------------===//
2187 // ExtractSliceOp
2188 //===----------------------------------------------------------------------===//
2189 
2190 void ExtractSliceOp::getAsmResultNames(
2191  function_ref<void(Value, StringRef)> setNameFn) {
2192  setNameFn(getResult(), "extracted_slice");
2193 }
2194 
2195 /// An extract_slice result type can be inferred, when it is not
2196 /// rank-reduced, from the source type and the static representation of
2197 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
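///
/// For example (illustrative only): a source tensor<8x16x4xf32> with static
/// sizes [4, 8, 2] infers tensor<4x8x2xf32>; a size equal to
/// ShapedType::kDynamic infers '?' in that position.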
2198 RankedTensorType ExtractSliceOp::inferResultType(
2199  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2200  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2201  // An extract_slice op may specify only a leading subset of offsets/sizes/
2202  // strides, in which case we complete with offset=0, sizes from the source
2203  // tensor type, and strides=1.
2204  assert(static_cast<int64_t>(staticSizes.size()) ==
2205  sourceTensorType.getRank() &&
2206  "unexpected staticSizes not equal to rank of source");
2207  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2208  sourceTensorType.getEncoding());
2209 }
2210 
2211 RankedTensorType ExtractSliceOp::inferResultType(
2212  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2213  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2214  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2215  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2216  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2217  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2218  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2219  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2220  staticSizes, staticStrides);
2221 }
2222 
2223 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2224 /// number of sizes), drop as many size-1 dimensions as needed to produce an
2225 /// inferred type with the desired rank.
2226 ///
2227 /// Note that there may be multiple ways to compute this rank-reduced type:
2228 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2229 ///
2230 /// To disambiguate, this function always drops the leftmost size-1 dimensions.
2231 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2232  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2233  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2234  ArrayRef<int64_t> strides) {
2235  // Type inferred in the absence of rank-reducing behavior.
2236  auto inferredType = llvm::cast<RankedTensorType>(
2237  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2238  int rankDiff = inferredType.getRank() - desiredResultRank;
2239  if (rankDiff > 0) {
2240  auto shape = inferredType.getShape();
2241  llvm::SmallBitVector dimsToProject =
2242  getPositionsOfShapeOne(rankDiff, shape);
2243  SmallVector<int64_t> projectedShape;
2244  // Best effort rank-reducing: drop 1s in order.
2245  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2246  if (!dimsToProject.test(pos))
2247  projectedShape.push_back(shape[pos]);
2248  inferredType =
2249  RankedTensorType::get(projectedShape, inferredType.getElementType());
2250  }
2251  return inferredType;
2252 }
2253 
2254 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2255  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2256  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2257  ArrayRef<OpFoldResult> strides) {
2258  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2259  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2260  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2261  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2262  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2263  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2264  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2265  staticStrides);
2266 }
2267 
2268 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2269 /// result type. If the type passed is nullptr, it is inferred.
2270 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2271  RankedTensorType resultType, Value source,
2272  ArrayRef<OpFoldResult> offsets,
2273  ArrayRef<OpFoldResult> sizes,
2274  ArrayRef<OpFoldResult> strides,
2275  ArrayRef<NamedAttribute> attrs) {
2276  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2277  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2278  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2279  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2280  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2281  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2282  // Structuring implementation this way avoids duplication between builders.
2283  if (!resultType) {
2284  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2285  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2286  }
2287  result.addAttributes(attrs);
2288  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2289  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2290  b.getDenseI64ArrayAttr(staticSizes),
2291  b.getDenseI64ArrayAttr(staticStrides));
2292 }
2293 
2294 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2295 /// result type.
2296 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2297  ArrayRef<OpFoldResult> offsets,
2298  ArrayRef<OpFoldResult> sizes,
2299  ArrayRef<OpFoldResult> strides,
2300  ArrayRef<NamedAttribute> attrs) {
2301  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2302 }
2303 
2304 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2305 /// a Range vector.
2306 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2307  ArrayRef<Range> ranges,
2308  ArrayRef<NamedAttribute> attrs) {
2309  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2310  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2311 }
2312 
2313 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2314 /// the type passed is nullptr, it is inferred.
2315 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2316  RankedTensorType resultType, Value source,
2317  ValueRange offsets, ValueRange sizes,
2318  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2319  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2320  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2321  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2322  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2323  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2324  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2325  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2326 }
2327 
2328 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2329 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2330  ValueRange offsets, ValueRange sizes,
2331  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2332  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2333 }
2334 
2335 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2336  Operation *op,
2337  RankedTensorType expectedType) {
2338  switch (result) {
2339  case SliceVerificationResult::Success:
2340  return success();
2341  case SliceVerificationResult::RankTooLarge:
2342  return op->emitError("expected rank to be smaller or equal to ")
2343  << "the other rank. ";
2344  case SliceVerificationResult::SizeMismatch:
2345  return op->emitError("expected type to be ")
2346  << expectedType << " or a rank-reduced version. (size mismatch) ";
2347  case SliceVerificationResult::ElemTypeMismatch:
2348  return op->emitError("expected element type to be ")
2349  << expectedType.getElementType();
2350  default:
2351  llvm_unreachable("unexpected extract_slice op verification result");
2352  }
2353 }
2354 
2355 /// Verify that the offsets/sizes/strides-style access into the given tensor
2356 /// is in-bounds. Only static information is verified.
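/// As a worked example (numbers invented): for a source dimension of size 10,
/// an access with offset 2, size 4, stride 3 touches the last position
/// 2 + (4 - 1) * 3 = 11, which is >= 10 and is therefore rejected.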
2357 static LogicalResult verifyInBoundsSlice(Operation *op,
2358  RankedTensorType tensorType,
2359  ArrayRef<int64_t> staticOffsets,
2360  ArrayRef<int64_t> staticSizes,
2361  ArrayRef<int64_t> staticStrides) {
2362  for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
2363  // Nothing to verify for dynamic source dims.
2364  if (tensorType.isDynamicDim(i))
2365  continue;
2366  // Nothing to verify if the offset is dynamic.
2367  if (ShapedType::isDynamic(staticOffsets[i]))
2368  continue;
2369  if (staticOffsets[i] >= tensorType.getDimSize(i))
2370  return op->emitOpError("offset ")
2371  << i << " is out-of-bounds: " << staticOffsets[i]
2372  << " >= " << tensorType.getDimSize(i);
2373  if (ShapedType::isDynamic(staticSizes[i]) ||
2374  ShapedType::isDynamic(staticStrides[i]))
2375  continue;
2376  int64_t lastPos =
2377  staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
2378  if (lastPos >= tensorType.getDimSize(i))
2379  return op->emitOpError("slice along dimension ")
2380  << i << " runs out-of-bounds: " << lastPos
2381  << " >= " << tensorType.getDimSize(i);
2382  }
2383  return success();
2384 }
2385 
2386 /// Verifier for ExtractSliceOp.
2387 LogicalResult ExtractSliceOp::verify() {
2388  RankedTensorType sourceType = getSourceType();
2389 
2390  // Verify result type against inferred type.
2391  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2392  sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2393  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2394  if (result != SliceVerificationResult::Success)
2395  return produceSliceErrorMsg(result, *this, expectedType);
2396 
2397  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2398  // to the source tensor.
2399  return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
2400  getStaticSizes(), getStaticStrides());
2401 }
2402 
2403 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2404  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2405 }
2406 
2407 FailureOr<Value>
2408 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2409  ArrayRef<int64_t> desiredShape) {
2410  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2411  assert(sourceTensorType && "not a ranked tensor type");
2412  auto sourceShape = sourceTensorType.getShape();
2413  if (sourceShape.equals(desiredShape))
2414  return value;
2415  auto maybeRankReductionMask =
2416  mlir::computeRankReductionMask(sourceShape, desiredShape);
2417  if (!maybeRankReductionMask)
2418  return failure();
2419  return createCanonicalRankReducingExtractSliceOp(
2420  b, loc, value,
2421  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2422 }
2423 
2424 LogicalResult ExtractSliceOp::reifyResultShapes(
2425  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2426  reifiedReturnShapes.resize(1);
2427  reifiedReturnShapes[0].reserve(getType().getRank());
2428  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2429  llvm::SmallBitVector droppedDims = getDroppedDims();
2430  for (const auto &size : enumerate(mixedSizes)) {
2431  if (droppedDims.test(size.index()))
2432  continue;
2433  reifiedReturnShapes[0].push_back(size.value());
2434  }
2435  return success();
2436 }
2437 
2438 namespace {
2439 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2440 /// This essentially pushes the tensor.cast past its consuming slice when
2441 /// `canFoldIntoConsumerOp` is true.
2442 ///
2443 /// Example:
2444 /// ```
2445 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2446 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2447 /// tensor<3x4xf32>
2448 /// ```
2449 /// is rewritten into:
2450 /// ```
2451 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2452 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2453 /// ```
2454 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2455 public:
2456  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2457 
2458  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2459  PatternRewriter &rewriter) const override {
2460  // Any constant operand, just return to let the constant folder kick in.
2461  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2462  return matchPattern(operand, matchConstantIndex());
2463  }))
2464  return failure();
2465 
2466  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2467  if (!castOp)
2468  return failure();
2469 
2470  if (!canFoldIntoConsumerOp(castOp))
2471  return failure();
2472 
2473  // Create folded extract.
2474  Location loc = sliceOp.getLoc();
2475  Value newResult = rewriter.create<ExtractSliceOp>(
2476  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2477  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2478  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2479  rewriter.replaceOp(sliceOp, newResult);
2480  return success();
2481  }
2482 };
2483 
2484 /// Slice elements from `values` into `outValues`. For each dimension,
2485 /// `counts` holds the number of elements to skip in `values` when stepping
2486 /// along that dimension. The output can be used to build a DenseElementsAttr.
2487 template <typename IterTy, typename ElemTy>
2488 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2489  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2490  ArrayRef<int64_t> strides,
2491  llvm::SmallVectorImpl<ElemTy> *outValues) {
2492  assert(offsets.size() == sizes.size());
2493  assert(offsets.size() == strides.size());
2494  if (offsets.empty())
2495  return;
2496 
2497  int64_t offset = offsets.front();
2498  int64_t size = sizes.front();
2499  int64_t stride = strides.front();
2500  if (offsets.size() == 1) {
2501  for (int64_t i = 0; i < size; ++i, offset += stride)
2502  outValues->push_back(*(values + offset));
2503 
2504  return;
2505  }
2506 
2507  for (int64_t i = 0; i < size; ++i, offset += stride) {
2508  auto begin = values + offset * counts.front();
2509  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2510  offsets.drop_front(), sizes.drop_front(),
2511  strides.drop_front(), outValues);
2512  }
2513 }
2514 
2515 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2516 /// folded operation might introduce more constant data; users can control
2517 /// whether the fold applies through the control function.
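/// A sketch of the rewrite, with values invented for illustration:
///
///   %cst = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32>
///   %s = tensor.extract_slice %cst[1] [2] [1]
///       : tensor<4xi32> to tensor<2xi32>
///
/// becomes
///
///   %s = arith.constant dense<[1, 2]> : tensor<2xi32>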
2518 class ConstantOpExtractSliceFolder final
2519  : public OpRewritePattern<ExtractSliceOp> {
2520 public:
2521  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2522 
2523  ConstantOpExtractSliceFolder(MLIRContext *context,
2524  ControlConstantExtractSliceFusionFn controlFn)
2525  : OpRewritePattern<ExtractSliceOp>(context),
2526  controlFn(std::move(controlFn)) {}
2527 
2528  LogicalResult matchAndRewrite(ExtractSliceOp op,
2529  PatternRewriter &rewriter) const override {
2530  DenseElementsAttr attr;
2531  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2532  return failure();
2533 
2534  // A constant splat is handled by fold().
2535  if (attr.isSplat())
2536  return failure();
2537 
2538  // Dynamic result shape is not supported.
2539  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2540  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2541  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2542  return failure();
2543 
2544  // Customized control over the folding.
2545  if (!controlFn(op))
2546  return failure();
2547 
2548  int64_t count = sourceType.getNumElements();
2549  if (count == 0)
2550  return failure();
2551 
2552  // Check if there are any dynamic parts, which are not supported.
2553  auto offsets = op.getStaticOffsets();
2554  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2555  return failure();
2556  auto sizes = op.getStaticSizes();
2557  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2558  return failure();
2559  auto strides = op.getStaticStrides();
2560  if (llvm::is_contained(strides, ShapedType::kDynamic))
2561  return failure();
2562 
2563  // Compute the stride for each dimension.
2564  SmallVector<int64_t> counts;
2565  ArrayRef<int64_t> shape = sourceType.getShape();
2566  counts.reserve(shape.size());
2567  for (int64_t v : shape) {
2568  count = count / v;
2569  counts.push_back(count);
2570  }
2571 
2572  // New attribute constructed by the sliced values.
2573  DenseElementsAttr newAttr;
2574 
2575  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2576  SmallVector<APInt> outValues;
2577  outValues.reserve(sourceType.getNumElements());
2578  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2579  elems.begin(), counts, offsets, sizes, strides, &outValues);
2580  newAttr = DenseElementsAttr::get(resultType, outValues);
2581  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2582  SmallVector<APFloat> outValues;
2583  outValues.reserve(sourceType.getNumElements());
2584  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2585  elems.begin(), counts, offsets, sizes, strides, &outValues);
2586  newAttr = DenseElementsAttr::get(resultType, outValues);
2587  }
2588 
2589  if (newAttr) {
2590  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2591  return success();
2592  }
2593 
2594  return failure();
2595  }
2596 
2597 private:
2598  /// Additionally controls whether the fold happens; users can plug their own
2599  /// heuristics into this function.
2600  ControlConstantExtractSliceFusionFn controlFn;
2601 };
2602 
2603 } // namespace
2604 
2605 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2606  RewritePatternSet &patterns,
2607  const ControlConstantExtractSliceFusionFn &controlFn) {
2608  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2609 }
2610 
2611 /// Return the canonical type of the result of an extract_slice op.
2612 struct SliceReturnTypeCanonicalizer {
2613  RankedTensorType operator()(ExtractSliceOp op,
2614  ArrayRef<OpFoldResult> mixedOffsets,
2615  ArrayRef<OpFoldResult> mixedSizes,
2616  ArrayRef<OpFoldResult> mixedStrides) {
2617  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2618  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2619  mixedStrides);
2620  }
2621 };
2622 
2623 /// A canonicalizer wrapper to replace ExtractSliceOps.
2624 struct SliceCanonicalizer {
2625  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2626  ExtractSliceOp newOp) {
2627  Value replacement = newOp.getResult();
2628  if (replacement.getType() != op.getType())
2629  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2630  replacement);
2631  rewriter.replaceOp(op, replacement);
2632  }
2633 };
2634 
2635 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2636  MLIRContext *context) {
2637  results.add<
2638  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2639  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2640  ExtractSliceOpCastFolder>(context);
2641 }
2642 
2643 // Return success if `op` is an identity slice of `shapedType`: all offsets are 0, sizes equal the shape, and strides are 1.
2644 static LogicalResult
2645 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2646  ShapedType shapedType) {
2647  OpBuilder b(op.getContext());
2648  for (OpFoldResult ofr : op.getMixedOffsets())
2649  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2650  return failure();
2651  // Rank-reducing noops only need to inspect the leading dimensions:
2652  // llvm::zip is appropriate.
2653  auto shape = shapedType.getShape();
2654  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2655  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2656  return failure();
2657  for (OpFoldResult ofr : op.getMixedStrides())
2658  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2659  return failure();
2660  return success();
2661 }
2662 
2663 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2664 /// slice, we can return the InsertSliceOp's source directly.
2665 // TODO: This only checks the immediate producer; extend to go up the
2666 // insert/extract chain if the slices are disjoint.
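// For illustration (types invented):
//   %0 = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x8xf32>
//   %1 = tensor.extract_slice %0[0, 0] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
// folds %1 to %src, provided offsets, sizes, and strides match exactly.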
2667 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2668  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2669 
2670  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2671  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2672  insertOp.isSameAs(extractOp, isSame))
2673  return insertOp.getSource();
2674 
2675  return {};
2676 }
2677 
2678 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2679  if (OpFoldResult reshapedSource = reshapeConstantSource(
2680  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2681  getResult().getType()))
2682  return reshapedSource;
2683  if (getSourceType() == getType() &&
2684  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2685  return this->getSource();
2686  if (Value slice = foldExtractAfterInsertSlice(*this))
2687  return slice;
2688 
2689  return OpFoldResult();
2690 }
2691 
2693  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2694  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2695  unsigned rank = rankedTensorType.getRank();
2696  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2697  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2698  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2699  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2700  offsets, sizes, strides);
2701 }
2702 
2703 //===----------------------------------------------------------------------===//
2704 // InsertSliceOp
2705 //===----------------------------------------------------------------------===//
2706 
2707 void InsertSliceOp::getAsmResultNames(
2708  function_ref<void(Value, StringRef)> setNameFn) {
2709  setNameFn(getResult(), "inserted_slice");
2710 }
2711 
2712 // Build an InsertSliceOp with mixed static and dynamic entries.
2713 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2714  Value dest, ArrayRef<OpFoldResult> offsets,
2715  ArrayRef<OpFoldResult> sizes,
2716  ArrayRef<OpFoldResult> strides,
2717  ArrayRef<NamedAttribute> attrs) {
2718  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2719  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2720  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2721  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2722  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2723  result.addAttributes(attrs);
2724  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2725  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2726  b.getDenseI64ArrayAttr(staticSizes),
2727  b.getDenseI64ArrayAttr(staticStrides));
2728 }
2729 
2730 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2731 /// Range vector.
2732 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2733  Value dest, ArrayRef<Range> ranges,
2734  ArrayRef<NamedAttribute> attrs) {
2735  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2736  build(b, result, source, dest, offsets, sizes, strides, attrs);
2737 }
2738 
2739 // Build an InsertSliceOp with dynamic entries.
2740 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2741  Value dest, ValueRange offsets, ValueRange sizes,
2742  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2743  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2744  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2745  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2746  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2747  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2748  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2749  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2750 }
2751 
2752 /// Rank-reducing type verification for both InsertSliceOp and
2753 /// ParallelInsertSliceOp.
2754 static SliceVerificationResult verifyInsertSliceOp(
2755  RankedTensorType srcType, RankedTensorType dstType,
2756  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2757  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2758  // insert_slice is the inverse of extract_slice, use the same type
2759  // inference.
2760  RankedTensorType expected = ExtractSliceOp::inferResultType(
2761  dstType, staticOffsets, staticSizes, staticStrides);
2762  if (expectedType)
2763  *expectedType = expected;
2764  return isRankReducedType(expected, srcType);
2765 }
2766 
2767 /// Verifier for InsertSliceOp.
2768 LogicalResult InsertSliceOp::verify() {
2769  // Verify result type against inferred type.
2770  RankedTensorType expectedType;
2771  SliceVerificationResult result =
2772  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2773  getStaticSizes(), getStaticStrides(), &expectedType);
2774  if (result != SliceVerificationResult::Success)
2775  return produceSliceErrorMsg(result, *this, expectedType);
2776 
2777  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2778  // to the source tensor.
2779  return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
2780  getStaticSizes(), getStaticStrides());
2781 }
2782 
2783 /// If we have two consecutive InsertSliceOps writing to the same slice, we
2784 /// can redirect the second one's destination to the first one's destination.
2785 ///
2786 /// Example:
2787 ///
2788 /// ```mlir
2789 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2790 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2791 /// ```
2792 ///
2793 /// folds into:
2794 ///
2795 /// ```mlir
2796 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2797 /// ```
2798 ///
2799 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2800 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2801  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2802 
2803  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2804  if (!prevInsertOp ||
2805  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2806  !prevInsertOp.isSameAs(insertOp, isSame))
2807  return failure();
2808 
2809  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2810  return success();
2811 }
2812 
2813 /// Folds round-trip extract/insert slice op pairs.
2814 /// Example:
2815 /// ```mlir
2816 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2817 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2818 /// ```
2819 /// can be folded into %val.
2820 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2821  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2822 
2823  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2824  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2825  !extractOp.isSameAs(insertOp, isSame))
2826  return nullptr;
2827 
2828  return extractOp.getSource();
2829 }
2830 
2831 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2832  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2833  getSourceType() == getType() &&
2834  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2835  return this->getSource();
2836  if (succeeded(foldInsertAfterInsertSlice(*this)))
2837  return getResult();
2838  if (auto result = foldInsertAfterExtractSlice(*this))
2839  return result;
2840  if (llvm::any_of(getMixedSizes(),
2841  [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2842  return getDest();
2843  return OpFoldResult();
2844 }
2845 
2846 LogicalResult InsertSliceOp::reifyResultShapes(
2847  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2848  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2849  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2850  return success();
2851 }
2852 
2853 namespace {
2854 /// Pattern to rewrite an insert_slice op with constant arguments.
2855 ///
2856 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
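///
/// A sketch, with values invented for illustration: if %sz is known to be the
/// constant 4, then
///
///   %r = tensor.insert_slice %s into %d[0, 0] [1, %sz] [1, 1]
///       : tensor<1x?xf32> into tensor<8x8xf32>
///
/// is rewritten to use the static size, inserting a tensor.cast of %s to
/// tensor<1x4xf32>:
///
///   %r = tensor.insert_slice %cast into %d[0, 0] [1, 4] [1, 1]
///       : tensor<1x4xf32> into tensor<8x8xf32>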
2857 template <typename InsertOpTy>
2858 class InsertSliceOpConstantArgumentFolder final
2859  : public OpRewritePattern<InsertOpTy> {
2860 public:
2861  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2862 
2863  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2864  PatternRewriter &rewriter) const override {
2865  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2866  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2867  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2868 
2869  // No constant operands were folded, just return.
2870  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2871  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2872  failed(foldDynamicStrideList(mixedStrides)))
2873  return failure();
2874 
2875  // Create the new op in canonical form.
2876  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2877  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2878  mixedOffsets, mixedSizes, mixedStrides);
2879  Value toInsert = insertSliceOp.getSource();
2880  if (sourceType != insertSliceOp.getSourceType()) {
2881  OpBuilder::InsertionGuard g(rewriter);
2882  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2883  // is that the insertion point is just before the ParallelCombiningOp in
2884  // the parallel case.
2885  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2886  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2887  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2888  sourceType, toInsert);
2889  }
2890  rewriter.replaceOpWithNewOp<InsertOpTy>(
2891  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2892  mixedSizes, mixedStrides);
2893  return success();
2894  }
2895 };
2896 
2897 /// Fold tensor.cast ops into insert_slice operations. If the source or
2898 /// destination tensor is a tensor.cast that removes static type information,
2899 /// the cast is folded into the insert_slice operation. E.g.:
2900 ///
2901 /// ```mlir
2902 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2903 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2904 /// ```
2905 ///
2906 /// folds into:
2907 ///
2908 /// ```mlir
2909 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2910 /// ```
2911 ///
2912 /// Note: When folding a cast on the destination tensor, the result of the
2913 /// insert_slice operation is cast back to ensure that the result type does
2914 /// not change.
2915 ///
2916 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2917 template <typename InsertOpTy>
2918 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2919  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2920 
2921  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2922  PatternRewriter &rewriter) const override {
2923  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2924  return matchPattern(operand, matchConstantIndex());
2925  }))
2926  return failure();
2927 
2928  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2929  auto castOp = v.getDefiningOp<tensor::CastOp>();
2930  if (!castOp || !canFoldIntoConsumerOp(castOp))
2931  return std::nullopt;
2932  return castOp.getSource();
2933  };
2934  std::optional<Value> sourceCastSource =
2935  getSourceOfCastOp(insertSliceOp.getSource());
2936  std::optional<Value> destCastSource =
2937  getSourceOfCastOp(insertSliceOp.getDest());
2938  if (!sourceCastSource && !destCastSource)
2939  return failure();
2940 
2941  auto src =
2942  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2943  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2944  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2945  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2946  if (!srcType || !dstType)
2947  return failure();
2948 
2949  // The tensor.cast source could have additional static information not seen
2950  // in the insert slice op static sizes, so we ignore dynamic dims when
2951  // computing the rank reduction mask.
2952  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2953  auto rankReductionMask = computeRankReductionMask(
2954  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2955  if (!rankReductionMask.has_value())
2956  return failure();
2957  // Replace dimensions in the insert slice op with corresponding static dims
2958  // from the cast source type. If the insert slice sizes have static dims
2959  // that are not static in the tensor.cast source (i.e., when the cast op
2960  // casts a dynamic dim to static), the dim should not be replaced, and the
2961  // pattern will fail later in `verifyInsertSliceOp`.
2962  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2963  int64_t rankReducedIdx = 0;
2964  for (auto [idx, size] : enumerate(staticSizes)) {
2965  if (!rankReductionMask.value().contains(idx) &&
2966  !srcType.isDynamicDim(rankReducedIdx)) {
2967  mixedSizes[idx] = getAsIndexOpFoldResult(
2968  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2969  size = srcType.getDimSize(rankReducedIdx++);
2970  }
2971  }
2972  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2973  staticSizes, insertSliceOp.getStaticStrides()) !=
2974  SliceVerificationResult::Success)
2975  return failure();
2976 
2977  Operation *replacement = rewriter.create<InsertOpTy>(
2978  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2979  mixedSizes, insertSliceOp.getMixedStrides());
2980 
2981  // In the parallel case there is no result and so nothing to cast.
2982  bool isParallelInsert =
2983  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2984  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2985  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2986  insertSliceOp.getDestType(),
2987  replacement->getResult(0));
2988  }
2989  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2990  return success();
2991  }
2992 };
2993 
2994 /// If additional static type information can be deduced from an insert_slice's
2995 /// size operands, insert an explicit cast of the op's source operand. This
2996 /// enables other canonicalization patterns that match tensor_cast ops, such
2997 /// as `ForOpTensorCastFolder` in SCF.
2998 ///
2999 /// Example:
3000 ///
3001 /// ```mlir
3002 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3003 /// : tensor<?x?xf32> into ...
3004 /// ```
3005 ///
3006 /// folds into:
3007 ///
3008 /// ```mlir
3009 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3010 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3011 /// : tensor<64x64xf32> into ...
3012 /// ```
3013 ///
3014 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3015 template <typename InsertOpTy>
3016 struct InsertSliceOpSourceCastInserter final
3017  : public OpRewritePattern<InsertOpTy> {
3018  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3019 
3020  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3021  PatternRewriter &rewriter) const override {
3022  RankedTensorType srcType = insertSliceOp.getSourceType();
3023  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3024  return failure();
3025  SmallVector<int64_t> newSrcShape(srcType.getShape());
3026  for (int64_t i = 0; i < srcType.getRank(); ++i) {
3027  if (std::optional<int64_t> constInt =
3028  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3029  // Bail on invalid IR.
3030  if (*constInt < 0)
3031  return failure();
3032  newSrcShape[i] = *constInt;
3033  }
3034  }
3035  if (!hasValidSizesOffsets(newSrcShape))
3036  return failure();
3037 
3038  RankedTensorType newSrcType = RankedTensorType::get(
3039  newSrcShape, srcType.getElementType(), srcType.getEncoding());
3040  if (srcType == newSrcType ||
3041  !preservesStaticInformation(srcType, newSrcType) ||
3042  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3043  return failure();
3044 
3045  // newSrcType is:
3046  // 1) Different from srcType.
3047  // 2) "More static" than srcType.
3048  // 3) Cast-compatible with srcType.
3049  // Insert the cast.
3050  OpBuilder::InsertionGuard g(rewriter);
3051  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3052  // that the insertion point is just before the ParallelCombiningOp in the
3053  // parallel case.
3054  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
3055  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3056  Value cast = rewriter.create<tensor::CastOp>(
3057  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
3058  rewriter.replaceOpWithNewOp<InsertOpTy>(
3059  insertSliceOp, cast, insertSliceOp.getDest(),
3060  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3061  insertSliceOp.getMixedStrides());
3062  return success();
3063  }
3064 };
3065 } // namespace
3066 
3067 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3068  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3069 }
3070 
3071 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3072  MLIRContext *context) {
3073  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3074  InsertSliceOpCastFolder<InsertSliceOp>,
3075  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3076 }
3077 
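/// A rough sketch of what the helper below builds (shapes are illustrative):
/// with %t : tensor<4xf32> and %d : tensor<1x4xf32>, it creates roughly
///
/// ```mlir
/// %r = tensor.insert_slice %t into %d[0, 0] [1, 4] [1, 1]
///     : tensor<4xf32> into tensor<1x4xf32>
/// ```
///
/// i.e. a rank-reducing insertion at offset zero, with unit strides, that
/// covers the whole destination.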
3078 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3079  Location loc,
3080  Value tensor,
3081  Value dest) {
3082  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3083  unsigned rank = rankedTensorType.getRank();
3084  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3085  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3086  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3087  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3088  sizes, strides);
3089 }
3090 
3091 //===----------------------------------------------------------------------===//
3092 // PadOp
3093 //===----------------------------------------------------------------------===//
3094 
3095 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3096  setNameFn(getResult(), "padded");
3097 }
3098 
3099 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3100 // supports optional types.
3101 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3102  Type typeToInfer, Type typeToInferFrom) {}
3103 
3104 ParseResult
3105 parseInferType(OpAsmParser &parser,
3106  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3107  Type &typeToInfer, Type typeToInferFrom) {
3108  if (optOperand)
3109  typeToInfer = typeToInferFrom;
3110  return success();
3111 }
3112 
3113 LogicalResult PadOp::verify() {
3114  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3115  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3116  auto expectedType =
3117  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3118  if (!expectedType) {
3119  return emitError("failed to infer expectedType from sourceType ")
3120  << sourceType << ", specified resultType is " << resultType;
3121  }
3122  if (resultType.getRank() != expectedType.getRank()) {
3123  return emitError("specified type ")
3124  << resultType << " does not match the inferred type "
3125  << expectedType;
3126  }
3127  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3128  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3129  continue;
3130  if (expectedType.isDynamicDim(i))
3131  continue;
3132  return emitError("specified type ")
3133  << resultType << " does not match the inferred type "
3134  << expectedType;
3135  }
3136 
3137  return success();
3138 }
3139 
3140 LogicalResult PadOp::verifyRegions() {
3141  auto &region = getRegion();
3142  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3143  Block &block = region.front();
3144  if (block.getNumArguments() != rank)
3145  return emitError("expected the block to have ") << rank << " arguments";
3146 
3147  // Note: the number and type of yield values are checked in the YieldOp.
3148  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3149  if (!en.value().isIndex())
3150  return emitOpError("expected block argument ")
3151  << (en.index() + 1) << " to be an index";
3152  }
3153 
3154  // Ensure that the region yields an element of the right type.
3155  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3156  if (yieldOp.getValue().getType() !=
3157  llvm::cast<ShapedType>(getType()).getElementType())
3158  return emitOpError("expected yield type to match shape element type");
3159 
3160  return success();
3161 }
3162 
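/// For illustration (shapes are made up): padding a tensor<4x?xf32> source
/// with staticLow = [1, 0] and staticHigh = [2, 3] infers the result type
/// tensor<7x?xf32>; any dimension with a dynamic source size or a dynamic
/// padding amount stays dynamic.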
3163 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3164  ArrayRef<int64_t> staticLow,
3165  ArrayRef<int64_t> staticHigh,
3166  ArrayRef<int64_t> resultShape) {
3167  unsigned rank = sourceType.getRank();
3168  if (staticLow.size() != rank)
3169  return RankedTensorType();
3170  if (staticHigh.size() != rank)
3171  return RankedTensorType();
3172  if (!resultShape.empty() && resultShape.size() != rank)
3173  return RankedTensorType();
3174 
3175  SmallVector<int64_t, 4> inferredShape;
3176  for (auto i : llvm::seq<unsigned>(0, rank)) {
3177  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3178  staticHigh[i] == ShapedType::kDynamic) {
3179  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3180  : resultShape[i]);
3181  } else {
3182  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3183  assert((resultShape.empty() || size == resultShape[i] ||
3184  resultShape[i] == ShapedType::kDynamic) &&
3185  "mismatch between inferred shape and result shape");
3186  inferredShape.push_back(size);
3187  }
3188  }
3189 
3190  return RankedTensorType::get(inferredShape, sourceType.getElementType());
3191 }
3192 
3193 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3194  Value source, ArrayRef<int64_t> staticLow,
3195  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3196  bool nofold, ArrayRef<NamedAttribute> attrs) {
3197  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3198  if (!resultType)
3199  resultType = inferResultType(sourceType, staticLow, staticHigh);
3200  result.addAttributes(attrs);
3201  build(b, result, resultType, source, low, high,
3202  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3203  nofold ? b.getUnitAttr() : UnitAttr());
3204 }
3205 
3206 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3207  Value source, ValueRange low, ValueRange high, bool nofold,
3208  ArrayRef<NamedAttribute> attrs) {
3209  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3210  unsigned rank = sourceType.getRank();
3211  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3212  build(b, result, resultType, source, staticVector, staticVector, low, high,
3213  nofold, attrs);
3214 }
3215 
3216 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3217  Value source, ArrayRef<OpFoldResult> low,
3218  ArrayRef<OpFoldResult> high, bool nofold,
3219  ArrayRef<NamedAttribute> attrs) {
3220  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3221  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3222  SmallVector<int64_t, 4> staticLow, staticHigh;
3223  // staticLow and staticHigh carry the full information of the padding
3224  // config. Each entry grows staticLow and staticHigh by one value. If the
3225  // entry is dynamic (i.e., not a constant), dynamicLow and dynamicHigh grow
3226  // by one value as well.
3227  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3228  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3229  if (!resultType) {
3230  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3231  }
3232  assert(llvm::isa<RankedTensorType>(resultType));
3233  result.addAttributes(attrs);
3234  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3235  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3236  nofold ? b.getUnitAttr() : UnitAttr());
3237 }
3238 
3239 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3240  Value source, ArrayRef<OpFoldResult> low,
3241  ArrayRef<OpFoldResult> high, Value constantPadValue,
3242  bool nofold, ArrayRef<NamedAttribute> attrs) {
3243  build(b, result, resultType, source, low, high, nofold, attrs);
3244 
3245  // Add a region and a block to yield the pad value.
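  // For a rank-2 source, the generated region looks roughly as follows
  // (a sketch; the block-argument names are illustrative):
  //
  //   ^bb0(%arg0: index, %arg1: index):
  //     tensor.yield %constantPadValue : f32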
3246  Region *region = result.regions[0].get();
3247  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3248  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3249  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3250 
3251  // `builder.createBlock` changes the insertion point within the block. Create
3252  // a guard to reset the insertion point of the builder after it is destroyed.
3253  OpBuilder::InsertionGuard guard(b);
3254  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3255  b.create<tensor::YieldOp>(result.location, constantPadValue);
3256 }
3257 
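/// For illustration (values are made up): with low pads [0, 2] and high pads
/// [1, 0], dimension 0 is padded at the high end and dimension 1 at the low
/// end, so both bits of the returned vector are set.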
3258 llvm::SmallBitVector PadOp::getPaddedDims() {
3259  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3260  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3261  for (const auto &en : enumerate(paddingWidths))
3262  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3263  paddedDims.set(en.index());
3264  };
3265  extractPaddedDims(getMixedLowPad());
3266  extractPaddedDims(getMixedHighPad());
3267  return paddedDims;
3268 }
3269 
3270 namespace {
3271 // Folds tensor.pad when the padding is statically all-zero and the nofold
3272 // attribute doesn't request otherwise.
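// A rough sketch of the fold (shapes are made up):
//
// ```mlir
// %padded = tensor.pad %t low[0, 0] high[0, 0] { ... }
//     : tensor<?x4xf32> to tensor<6x4xf32>
// ```
//
// is rewritten to:
//
// ```mlir
// %padded = tensor.cast %t : tensor<?x4xf32> to tensor<6x4xf32>
// ```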
3273 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3274  using OpRewritePattern<PadOp>::OpRewritePattern;
3275 
3276  LogicalResult matchAndRewrite(PadOp padTensorOp,
3277  PatternRewriter &rewriter) const override {
3278  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3279  return failure();
3280  if (padTensorOp.getNofold())
3281  return failure();
3282  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3283  padTensorOp, padTensorOp.getResult().getType(),
3284  padTensorOp.getSource());
3285  return success();
3286  }
3287 };
3288 
3289 // Fold CastOp into PadOp when adding static information.
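// A rough sketch of this fold (shapes are made up):
//
// ```mlir
// %c = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
// %p = tensor.pad %c low[0, 0] high[2, 3] { ... }
//     : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// folds into:
//
// ```mlir
// %p = tensor.pad %0 low[0, 0] high[2, 3] { ... }
//     : tensor<8x16xf32> to tensor<10x19xf32>
// %r = tensor.cast %p : tensor<10x19xf32> to tensor<?x?xf32>
// ```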
3290 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3291  using OpRewritePattern<PadOp>::OpRewritePattern;
3292 
3293  LogicalResult matchAndRewrite(PadOp padTensorOp,
3294  PatternRewriter &rewriter) const override {
3295  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3296  if (!tensor::canFoldIntoConsumerOp(castOp))
3297  return failure();
3298 
3299  auto newResultType = PadOp::inferResultType(
3300  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3301  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3302  padTensorOp.getResultType().getShape());
3303 
3304  if (newResultType == padTensorOp.getResultType()) {
3305  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3306  padTensorOp.getSourceMutable().assign(castOp.getSource());
3307  });
3308  } else {
3309  auto newOp = rewriter.create<PadOp>(
3310  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3311  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3312  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3313  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3314  IRMapping mapper;
3315  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3316 
3317  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3318  padTensorOp, padTensorOp.getResultType(), newOp);
3319  }
3320  return success();
3321  }
3322 };
3323 
3324 // Fold CastOp using the result of PadOp back into the latter if it adds
3325 // static information.
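// A rough sketch of this fold (shapes are made up):
//
// ```mlir
// %p = tensor.pad %0 low[0] high[%h] { ... } : tensor<8xf32> to tensor<?xf32>
// %c = tensor.cast %p : tensor<?xf32> to tensor<12xf32>
// ```
//
// folds into:
//
// ```mlir
// %p = tensor.pad %0 low[0] high[%h] { ... } : tensor<8xf32> to tensor<12xf32>
// ```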
3326 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3327  using OpRewritePattern<PadOp>::OpRewritePattern;
3328 
3329  LogicalResult matchAndRewrite(PadOp padTensorOp,
3330  PatternRewriter &rewriter) const override {
3331  if (!padTensorOp.getResult().hasOneUse())
3332  return failure();
3333  auto tensorCastOp =
3334  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3335  if (!tensorCastOp)
3336  return failure();
3337  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3338  tensorCastOp.getDest().getType()))
3339  return failure();
3340 
3341  auto replacementOp = rewriter.create<PadOp>(
3342  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3343  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3344  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3345  padTensorOp.getHigh(), padTensorOp.getNofold(),
3346  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3347  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3348 
3349  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3350  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3351  return success();
3352  }
3353 };
3354 
3355 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3356 /// different dimensions. The pattern applies if the following preconditions
3357 /// hold:
3358 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3359 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3360 /// 3) the tensor::PadOps perform only high-padding,
3361 /// 4) the tensor::PadOps have the same constant padding value,
3362 /// 5) the tensor::PadOps do not have common padding dimensions,
3363 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3364 /// zero-offset for every dimension.
3365 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3366 ///    the padded source dimensions.
3368 ///
3369 /// Example:
3370 ///
3371 /// ```mlir
3372 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3373 /// : tensor<64x64xf32> to tensor<?x64xf32>
3374 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3375 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3376 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3377 /// : tensor<8x64xf32> to tensor<8x?xf32>
3378 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3379 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3380 /// ```
3381 ///
3382 /// folds into:
3383 ///
3384 /// ```mlir
3385 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3386 /// : tensor<64x64xf32> to tensor<?x?xf32>
3387 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3388 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3389 /// ```
3390 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3391  using OpRewritePattern<PadOp>::OpRewritePattern;
3392 
3393  LogicalResult matchAndRewrite(PadOp padOp,
3394  PatternRewriter &rewriter) const override {
3395  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3396  if (!innerSliceOp)
3397  return failure();
3398  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3399  if (!outerPadOp || outerPadOp.getNofold())
3400  return failure();
3401  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3402  if (!outerSliceOp)
3403  return failure();
3404 
3405  // 1) Fail if the chain is rank-reducing.
3406  int64_t rank = padOp.getSourceType().getRank();
3407  if (outerSliceOp.getSourceType().getRank() != rank) {
3408  return rewriter.notifyMatchFailure(padOp,
3409  "cannot fold rank-reducing chain");
3410  }
3411 
3412  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3413  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3414  return rewriter.notifyMatchFailure(
3415  padOp, "cannot fold non-unit stride ExtractSliceOps");
3416  }
3417 
3418  // 3) Fail if the tensor::PadOps have non-zero low padding.
3419  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3420  return rewriter.notifyMatchFailure(padOp,
3421  "cannot fold PadOps with low padding");
3422  }
3423 
3424  // 4) Fail if the tensor::PadOps padding values do not match.
3425  Attribute innerAttr, outerAttr;
3426  Value innerValue = padOp.getConstantPaddingValue();
3427  Value outerValue = outerPadOp.getConstantPaddingValue();
3428  if (!innerValue || !outerValue ||
3429  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3430  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3431  innerAttr != outerAttr) {
3432  return rewriter.notifyMatchFailure(
3433  padOp, "cannot fold PadOps with different padding values");
3434  }
3435 
3436  // 5) Fail if a dimension is padded by both tensor::PadOps.
3437  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3438  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3439  if (innerDims.anyCommon(outerDims)) {
3440  return rewriter.notifyMatchFailure(
3441  padOp, "cannot fold PadOps with common padding dimensions");
3442  }
3443 
3444  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3445  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3446  // for every dimension, and use the offset of the other pair. Fail if no
3447  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3448  // exists.
3449  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3450  for (auto en : enumerate(newOffsets)) {
3451  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3452  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3453  if (!innerDims.test(en.index()) &&
3454  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3455  en.value() = outerOffset;
3456  continue;
3457  }
3458  if (!outerDims.test(en.index()) &&
3459  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3460  en.value() = innerOffset;
3461  continue;
3462  }
3463  return rewriter.notifyMatchFailure(
3464  padOp, "cannot find zero-offset and zero-padding pair");
3465  }
3466 
3467  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3468  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3469  // outer tensor::PadOp and fail if the size of the inner
3470  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3471  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3472  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3473  for (auto en : enumerate(newSizes)) {
3474  if (!outerDims.test(en.index()))
3475  continue;
3476  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3477  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3478  assert(!ShapedType::isDynamic(sourceSize) &&
3479  "expected padded dimension to have a static size");
3480  if (getConstantIntValue(sliceSize) != sourceSize) {
3481  return rewriter.notifyMatchFailure(
3482  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3483  "match the size of the outer padding");
3484  }
3485  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3486  }
3487 
3488  // Combine the high paddings of the two tensor::PadOps.
3489  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3490  for (auto en : enumerate(newHighPad)) {
3491  if (innerDims.test(en.index()))
3492  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3493  if (outerDims.test(en.index()))
3494  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3495  }
3496 
3497  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3498  // the two paddings in one step.
3499  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3500  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3501  innerSliceOp.getMixedStrides());
3502  auto newPadOp = rewriter.create<PadOp>(
3503  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3504  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3505  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3506  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3507  newPadOp.getRegion().begin());
3508  rewriter.replaceOp(padOp, newPadOp.getResult());
3509  return success();
3510  }
3511 };
3512 
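// Replaces dynamic low/high padding amounts that are defined by constants
// with their static values, recomputes a more static result type, and casts
// back to the original type. A rough sketch (shapes are made up):
//
// ```mlir
// %c2 = arith.constant 2 : index
// %p = tensor.pad %0 low[%c2, 0] high[1, %h] { ... }
//     : tensor<4x?xf32> to tensor<?x?xf32>
// ```
//
// folds into:
//
// ```mlir
// %p = tensor.pad %0 low[2, 0] high[1, %h] { ... }
//     : tensor<4x?xf32> to tensor<7x?xf32>
// %r = tensor.cast %p : tensor<7x?xf32> to tensor<?x?xf32>
// ```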
3513 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3514  using OpRewritePattern<PadOp>::OpRewritePattern;
3515 
3516  LogicalResult matchAndRewrite(PadOp padTensorOp,
3517  PatternRewriter &rewriter) const override {
3518  Value input = padTensorOp.getSource();
3519  if (!llvm::isa<RankedTensorType>(input.getType()))
3520  return failure();
3521  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3522  auto inputRank = inputDims.size();
3523 
3524  auto oldResultType =
3525  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3526  if (!oldResultType)
3527  return failure();
3528 
3529  auto outputDims = oldResultType.getShape();
3530 
3531  // Extract the static info from the high and low operands.
3532  SmallVector<int64_t> constOperandsLow;
3533  SmallVector<Value> newLows;
3534  for (auto operand : padTensorOp.getLow()) {
3535  APSInt intOp;
3536  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3537  constOperandsLow.push_back(ShapedType::kDynamic);
3538  newLows.push_back(operand);
3539  continue;
3540  }
3541  constOperandsLow.push_back(intOp.getExtValue());
3542  }
3543  SmallVector<int64_t> constOperandsHigh;
3544  SmallVector<Value> newHighs;
3545  for (auto operand : padTensorOp.getHigh()) {
3546  APSInt intOp;
3547  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3548  constOperandsHigh.push_back(ShapedType::kDynamic);
3549  newHighs.push_back(operand);
3550  continue;
3551  }
3552  constOperandsHigh.push_back(intOp.getExtValue());
3553  }
3554 
3555  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3556  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3557 
3558  // Verify the op is well-formed.
3559  if (inputDims.size() != outputDims.size() ||
3560  inputDims.size() != constLow.size() ||
3561  inputDims.size() != constHigh.size())
3562  return failure();
3563 
3564  auto lowCount = 0;
3565  auto highCount = 0;
3566  for (size_t i = 0; i < inputRank; i++) {
3567  if (constLow[i] == ShapedType::kDynamic)
3568  constLow[i] = constOperandsLow[lowCount++];
3569  if (constHigh[i] == ShapedType::kDynamic)
3570  constHigh[i] = constOperandsHigh[highCount++];
3571  }
3572 
3573  auto staticLow = ArrayRef<int64_t>(constLow);
3574  auto staticHigh = ArrayRef<int64_t>(constHigh);
3575 
3576  // Calculate the output sizes with the static information.
3577  SmallVector<int64_t> newOutDims;
3578  for (size_t i = 0; i < inputRank; i++) {
3579  if (outputDims[i] == ShapedType::kDynamic) {
3580  newOutDims.push_back(
3581  (staticLow[i] == ShapedType::kDynamic ||
3582  staticHigh[i] == ShapedType::kDynamic ||
3583  inputDims[i] == ShapedType::kDynamic
3584  ? ShapedType::kDynamic
3585  : inputDims[i] + staticLow[i] + staticHigh[i]));
3586  } else {
3587  newOutDims.push_back(outputDims[i]);
3588  }
3589  }
3590 
3591  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3592  llvm::all_of(newOutDims,
3593  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3594  return failure();
3595 
3596  // Rewrite the op using the new static type.
3597  auto newResultType = RankedTensorType::get(
3598  newOutDims, padTensorOp.getType().getElementType());
3599  auto newOp = rewriter.create<PadOp>(
3600  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3601  newLows, newHighs, padTensorOp.getNofold(),
3602  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3603 
3604  IRMapping mapper;
3605  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3606  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3607  newOp);
3608 
3609  return success();
3610  }
3611 };
3612 
3613 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3614 ///
3615 /// Example:
3616 ///
3617 /// ```mlir
3618 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3619 /// tensor.yield %val
3620 /// } : tensor<1x2xf32> to tensor<2x5xf32>
3621 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3622 /// tensor.yield %val
3623 /// } : tensor<1x5xf32> to tensor<5x7xf32>
3624 /// ```
3625 ///
3626 /// folds into:
3627 ///
3628 /// ```mlir
3629 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3630 /// tensor.yield %val
3631 /// } : tensor<1x2xf32> to tensor<5x7xf32>
3632 /// ```
3633 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3634  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3635 
3636  LogicalResult matchAndRewrite(tensor::PadOp padOp,
3637  PatternRewriter &rewriter) const override {
3638  if (padOp.getNofold()) {
3639  return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3640  }
3641 
3642  auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3643  if (!producerPad || producerPad.getNofold()) {
3644  return rewriter.notifyMatchFailure(
3645  padOp, "producer is not a foldable tensor.pad op");
3646  }
3647 
3648  // Fail if the tensor::PadOps padding values do not match.
3649  Value consumerPadValue = padOp.getConstantPaddingValue();
3650  Value producerPadValue = producerPad.getConstantPaddingValue();
3651  if (!consumerPadValue || !producerPadValue ||
3652  consumerPadValue != producerPadValue) {
3653  return rewriter.notifyMatchFailure(
3654  padOp,
3655  "cannot fold PadOps with different or non-constant padding values");
3656  }
3657 
3658  Location loc = padOp.getLoc();
3659  AffineExpr d0, d1;
3660  bindDims(rewriter.getContext(), d0, d1);
3661 
3662  // Combine the low/high paddings of the two tensor::PadOps.
3663  auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3664  ArrayRef<OpFoldResult> producerPaddings) {
3665  SmallVector<OpFoldResult> sumPaddings;
3666  for (auto [consumerIndex, producerIndex] :
3667  llvm::zip_equal(consumerPaddings, producerPaddings)) {
3668  sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3669  rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3670  }
3671  return sumPaddings;
3672  };
3673 
3674  SmallVector<OpFoldResult> newHighPad =
3675  addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3676  SmallVector<OpFoldResult> newLowPad =
3677  addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3678 
3679  auto newPadOp = rewriter.create<tensor::PadOp>(
3680  padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3681  newLowPad, newHighPad, padOp.getNofold(),
3682  getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3683  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3684  newPadOp.getRegion().begin());
3685  rewriter.replaceOp(padOp, newPadOp.getResult());
3686  return success();
3687  }
3688 };
3689 
3690 } // namespace
3691 
3692 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3693  MLIRContext *context) {
3694  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3695  FoldOrthogonalPaddings, FoldStaticPadding,
3696  FoldConsecutiveConstantPadding>(context);
3697 }
3698 
3699 /// Return the padding value of the PadOp if it is constant. In this context,
3700 /// "constant" means an actual constant or "defined outside of the block".
3701 ///
3702 /// Values are considered constant in three cases:
3703 /// - A ConstantLike value.
3704 /// - A basic block argument from a different block.
3705 /// - A value defined outside of the block.
3706 ///
3707 /// If the padding value is not constant, an empty Value is returned.
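///
/// For illustration (a sketch, values made up): in
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %t low[0] high[2] {
///   tensor.yield %cst : f32
/// } : tensor<8xf32> to tensor<10xf32>
/// ```
///
/// the returned padding value is %cst, because it is a constant defined
/// outside of the pad body.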
3708 Value PadOp::getConstantPaddingValue() {
3709  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3710  if (!yieldOp)
3711  return {};
3712  Value padValue = yieldOp.getValue();
3713  // Check if yield value is a constant.
3714  if (matchPattern(padValue, m_Constant()))
3715  return padValue;
3716  // Check if yield value is defined inside the PadOp block.
3717  if (padValue.getParentBlock() == &getRegion().front())
3718  return {};
3719  // Else: Yield value defined outside of the PadOp block.
3720  return padValue;
3721 }
3722 
3723 OpFoldResult PadOp::fold(FoldAdaptor) {
3724  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3725  !getNofold())
3726  return getSource();
3727  return {};
3728 }
3729 
3730 //===----------------------------------------------------------------------===//
3731 // ParallelInsertSliceOp
3732 //===----------------------------------------------------------------------===//
3733 
3734 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3735  ParallelCombiningOpInterface parallelCombiningParent =
3736  getParallelCombiningParent();
3737  for (const auto &it :
3738  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3739  Operation &nextOp = it.value();
3740  if (&nextOp == getOperation())
3741  return parallelCombiningParent.getParentResult(it.index());
3742  }
3743  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3744 }
3745 
3746 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3747 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3748  Value source, Value dest,
3749  ArrayRef<OpFoldResult> offsets,
3750  ArrayRef<OpFoldResult> sizes,
3751  ArrayRef<OpFoldResult> strides,
3752  ArrayRef<NamedAttribute> attrs) {
3753  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3754  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3755  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3756  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3757  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3758  result.addAttributes(attrs);
3759  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3760  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3761  b.getDenseI64ArrayAttr(staticSizes),
3762  b.getDenseI64ArrayAttr(staticStrides));
3763 }
3764 
3765 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3766 /// packed into a Range vector.
3767 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3768  Value source, Value dest,
3769  ArrayRef<Range> ranges,
3770  ArrayRef<NamedAttribute> attrs) {
3771  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3772  build(b, result, source, dest, offsets, sizes, strides, attrs);
3773 }
3774 
3775 // Build a ParallelInsertSliceOp with dynamic entries.
3776 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3777  Value source, Value dest, ValueRange offsets,
3778  ValueRange sizes, ValueRange strides,
3779  ArrayRef<NamedAttribute> attrs) {
3780  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3781  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3782  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3783  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3784  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3785  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3786  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3787 }
3788 
3789 LogicalResult ParallelInsertSliceOp::verify() {
3790  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3791  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3792  << *(getOperation()->getParentOp());
3793 
3794  // Verify result type against inferred type.
3795  RankedTensorType expectedType;
3796  SliceVerificationResult result =
3797  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3798  getStaticSizes(), getStaticStrides(), &expectedType);
3799  if (result != SliceVerificationResult::Success)
3800  return produceSliceErrorMsg(result, *this, expectedType);
3801 
3802  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3803  // to the source tensor.
3804  return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
3805  getStaticSizes(), getStaticStrides());
3806 }
3807 
3808 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3809  RewritePatternSet &results, MLIRContext *context) {
3810  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3811  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3812  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3813 }
3814 
3815 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3816  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3817 }
3818 
3819 //===----------------------------------------------------------------------===//
3820 // ScatterOp
3821 //===----------------------------------------------------------------------===//
3822 
3823 void ScatterOp::getAsmResultNames(
3824  function_ref<void(Value, StringRef)> setNameFn) {
3825  setNameFn(getResult(), "scatter");
3826 }
3827 
3828 LogicalResult ScatterOp::verify() {
3829  int64_t destRank = getDestType().getRank();
3830  ArrayRef<int64_t> scatterDims = getScatterDims();
3831  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3832  getIndicesType().getShape(), destRank,
3833  "scatter", "dest")))
3834  return failure();
3835 
3836  if (!getUnique())
3837  return emitOpError("requires 'unique' attribute to be set");
3838  // TODO: we could also check statically that there are fewer leading index
3839  // tensor dims than the dest dims. If this is not the case, the unique
3840  // attribute cannot be true.
3841 
3842  // Use the GatherOp::inferResultType on the `dest` type and verify the
3843  // expected type matches the source type.
3844  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3845  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3846  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3847  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3848  if (getSourceType() != expectedSourceType &&
3849  getSourceType() != expectedRankReducedSourceType) {
3850  return emitOpError("source type "
3851  "mismatch: "
3852  "expected ")
3853  << expectedSourceType << " or its rank-reduced variant "
3854  << expectedRankReducedSourceType << " (got: " << getSourceType()
3855  << ")";
3856  }
3857 
3858  return success();
3859 }
3860 
3861 //===----------------------------------------------------------------------===//
3862 // SplatOp
3863 //===----------------------------------------------------------------------===//
3864 
3865 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3866  Type aggregateType, ValueRange dynamicSizes) {
3867  build(builder, result, aggregateType, element, dynamicSizes);
3868 }
3869 
3870 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3871  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3872  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3873  build(builder, result, aggregateType, element, dynamicSizes);
3874 }
3875 
3876 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3877  ArrayRef<OpFoldResult> sizes) {
3878  SmallVector<int64_t> staticShape;
3879  SmallVector<Value> dynamicSizes;
3880  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3881  build(builder, result, element, staticShape, dynamicSizes);
3882 }
3883 
3884 void SplatOp::getAsmResultNames(
3885  function_ref<void(Value, StringRef)> setNameFn) {
3886  setNameFn(getResult(), "splat");
3887 }
3888 
3889 LogicalResult SplatOp::verify() {
3890  if (getType().getNumDynamicDims() != getDynamicSizes().size())
3891  return emitOpError("incorrect number of dynamic sizes, has ")
3892  << getDynamicSizes().size() << ", expected "
3893  << getType().getNumDynamicDims();
3894  return success();
3895 }
3896 
3897 LogicalResult
3898 SplatOp::reifyResultShapes(OpBuilder &builder,
3899  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3900  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3901  unsigned ctr = 0;
3902  for (int64_t i = 0; i < getType().getRank(); ++i) {
3903  if (getType().isDynamicDim(i)) {
3904  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3905  } else {
3906  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3907  }
3908  }
3909  return success();
3910 }
3911 
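// A rough sketch of the fold below (values are illustrative): with
// %c = arith.constant 1.0 : f32,
//
// ```mlir
// %0 = tensor.splat %c : tensor<2x2xf32>
// ```
//
// folds to the splat attribute dense<1.0> : tensor<2x2xf32>. Splats with a
// dynamically shaped result type are not folded.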
3912 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3913  auto constOperand = adaptor.getInput();
3914  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3915  return {};
3916 
3917  // Do not fold if the splat is not statically shaped
3918  if (!getType().hasStaticShape())
3919  return {};
3920 
3921  // SplatElementsAttr::get treats a single value for the second arg as a
3922  // splat.
3923  return SplatElementsAttr::get(getType(), {constOperand});
3924 }
3925 
3926 //===----------------------------------------------------------------------===//
3927 // Common Canonicalizers and Folders.
3928 //===----------------------------------------------------------------------===//
3929 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
3930  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
3931  // 2. Exclude DPS ops that are also LoopLike from this interface as they
3932  // might need special handling of attached regions.
3933  if (isa<InsertSliceOp>(op.getOperation()) ||
3934  isa<LoopLikeOpInterface>(op.getOperation()))
3935  return false;
3936 
3937  return hasFoldableTensorCastOperand(op);
3938 }
3939 
3940 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
3941 /// the `tensor.cast` has a source that is more static than the consuming op.
3942 ///
3943 /// Example:
3944 /// ```mlir
3945 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3946 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
3947 /// ```
3948 ///
3949 /// folds into:
3950 ///
3951 /// ```mlir
3952 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
3953 /// ```
3954 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
3955 /// can add the pattern to their canonicalizers.
3956 struct FoldTensorCastProducerOp
3957  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
3958  using OpInterfaceRewritePattern<
3959  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
3960 
3961  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
3962  PatternRewriter &rewriter) const override {
3963 
3964  // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
3965  // for that instead.
3966  if (!foldTensorCastPrecondition(op) ||
3967  isa<linalg::RelayoutOpInterface>(*op))
3968  return failure();
3969 
3970  SmallVector<Type> newResultTypes(op->getResultTypes());
3971  SmallVector<Value> newOperands =
3972  getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
3973 
3974  // Clone op
3975  auto newOp = clone(rewriter, op, newResultTypes, newOperands);
3976 
3977  SmallVector<Value, 4> replacements;
3978  replacements.reserve(newOp->getNumResults());
3979  for (auto [oldResult, newResult] :
3980  llvm::zip(op->getResults(), newOp->getResults())) {
3981  if (newResult.getType() != oldResult.getType()) {
3982  replacements.push_back(rewriter.create<tensor::CastOp>(
3983  op->getLoc(), oldResult.getType(), newResult));
3984  } else {
3985  replacements.push_back(newResult);
3986  }
3987  }
3988  rewriter.replaceOp(op, replacements);
3989 
3990  return success();
3991  }
3992 };
3993 
3994 //===----------------------------------------------------------------------===//
3995 // TensorDialect
3996 //===----------------------------------------------------------------------===//
3997 
3998 void TensorDialect::getCanonicalizationPatterns(
3999  RewritePatternSet &results) const {
4000  results.add<FoldTensorCastProducerOp>(getContext());
4001 }
4002 
4003 //===----------------------------------------------------------------------===//
4004 // TableGen'd op method definitions
4005 //===----------------------------------------------------------------------===//
4006 
4007 #define GET_OP_CLASSES
4008 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition: TensorOps.cpp:418
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition: TensorOps.cpp:1370
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
Definition: TensorOps.cpp:2335
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:3105
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition: TensorOps.cpp:2800
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition: TensorOps.cpp:2645
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition: TensorOps.cpp:2667
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1628
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition: TensorOps.cpp:2754
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
Definition: TensorOps.cpp:183
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:3101
static LogicalResult verifyInBoundsSlice(Operation *op, RankedTensorType tensorType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides)
Verify that the offsets/sizes/strides-style access into the given tensor is in-bounds.
Definition: TensorOps.cpp:2357
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:139
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition: TensorOps.cpp:2820
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition: TensorOps.cpp:1856
bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
Definition: TensorOps.cpp:3929
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:364
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:214
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:225
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1208
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:78
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:389
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:358
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2605
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:351
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e., hasFoldableTensorCastOperand returns true), compute the updated operands with those casts folded away and store the corresponding result types in newResTy.
Definition: TensorOps.cpp:367
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:321
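A sketch of the typical consumer-side check; skipFoldableCast is an illustrative helper, not an MLIR API. When the defining tensor.cast only erases static information, the consumer can use the cast source directly and gain a more static operand type.

#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Illustrative helper: look through a tensor.cast that can fold into its
// consumer and return the (more static) source value.
static Value skipFoldableCast(Value operand) {
  if (auto castOp = operand.getDefiningOp<tensor::CastOp>())
    if (tensor::canFoldIntoConsumerOp(castOp))
      return castOp.getSource();
  return operand;
}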
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and sizes matching the destination shape.
Definition: TensorOps.cpp:3078
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and sizes matching targetType, reducing the rank of tensor to that of targetType.
Definition: TensorOps.cpp:2692
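A sketch combining the two canonical rank-reducing helpers above; the wrapper function and the assumed compatibility of source, targetType, and dest (same extents, with targetType dropping unit dimensions) are illustrative.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative: drop the unit dimensions of `source` down to `targetType`,
// then re-insert the rank-reduced value at offset [0 .. 0] of `dest`.
static Value roundTripRankReducedSlice(OpBuilder &b, Location loc,
                                       Value source,
                                       RankedTensorType targetType,
                                       Value dest) {
  Value reduced = tensor::createCanonicalRankReducingExtractSliceOp(
      b, loc, source, targetType);
  return tensor::createCanonicalRankReducingInsertSliceOp(b, loc, reduced,
                                                          dest);
}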
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:127
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:59
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:78
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:167
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source type.
Definition: TensorOps.cpp:269
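A small sketch of the expected behavior on builtin types (the helper is illustrative): a target at least as static as the source preserves its static information, while a target that turns a static extent dynamic does not.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void staticInfoExample(MLIRContext *ctx) {
  Type f32 = Float32Type::get(ctx);
  auto staticTy = RankedTensorType::get({4, 8}, f32);                // 4x8xf32
  auto dynTy = RankedTensorType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, f32);            // ?x?xf32
  // Dropping static extents loses information.
  bool drops = tensor::preservesStaticInformation(staticTy, dynTy);  // false
  // Refining dynamic extents loses nothing.
  bool keeps = tensor::preservesStaticInformation(dynTy, staticTy);  // true
  (void)drops;
  (void)keeps;
}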
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:69
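A usage sketch (the helper createEmptyLike is illustrative): building a tensor.empty whose shape mirrors an existing tensor value, with static extents returned as attributes and dynamic extents as tensor.dim results.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative: create a "same shape" tensor.empty for a ranked tensor value.
static Value createEmptyLike(OpBuilder &b, Location loc, Value source) {
  auto sourceType = cast<RankedTensorType>(source.getType());
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, source);
  return b.create<tensor::EmptyOp>(loc, sizes, sourceType.getElementType());
}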
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:113
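A sketch of the common call pattern (the wrapper gatherDestinations is illustrative): collect one destination per tensor result before rewriting an op in destination-passing style; the call fails when a dynamic result shape cannot be reified.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static FailureOr<SmallVector<Value>> gatherDestinations(OpBuilder &b,
                                                        Operation *op) {
  SmallVector<Value> destinations;
  if (failed(tensor::getOrCreateDestinations(b, op->getLoc(), op,
                                             destinations)))
    return failure();
  return destinations;   // One destination per tensor result of `op`.
}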
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
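A sketch of the matcher in use (getKnownInt is illustrative): binding the APInt payload of a constant-producing Value.

#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;

static std::optional<int64_t> getKnownInt(Value v) {
  llvm::APInt value;
  if (matchPattern(v, m_ConstantInt(&value)))
    return value.getSExtValue();
  return std::nullopt;       // Not defined by a constant integer (splat).
}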
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) formed by separately extracting each range's offset, size, and stride.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
Definition: BuiltinTypes.h:340
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
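A short sketch of the OpFoldResult helpers above (both wrappers are illustrative): querying a possibly-static extent, and detecting unit strides whether they are carried as attributes or as constant Values.

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include <optional>

using namespace mlir;

static std::optional<int64_t> knownExtent(OpFoldResult ofr) {
  return getConstantIntValue(ofr);      // std::nullopt if truly dynamic.
}

static bool isUnitStride(OpFoldResult stride) {
  return isConstantIntValue(stride, 1); // Attribute or constant-op Value.
}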
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that would prevent erasing.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:362
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:91
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but where each element for which ShapedType::isDynamic is true is replaced by the corresponding value from dynamicValues.
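A sketch of the round trip between the split (static int64_t array plus dynamic Values) storage used by ops and a single OpFoldResult vector; the wrapper function is illustrative.

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ValueRange.h"

using namespace mlir;

static void roundTripMixedSizes(MLIRContext *ctx,
                                ArrayRef<int64_t> staticSizes,
                                ValueRange dynamicSizes) {
  // Merge: entries equal to ShapedType::kDynamic are replaced by Values.
  SmallVector<OpFoldResult> mixed =
      getMixedValues(staticSizes, dynamicSizes, ctx);
  // Split back into the (static, dynamic) representation.
  auto [newStatic, newDynamic] = decomposeMixedValues(mixed);
  (void)newStatic;
  (void)newDynamic;
}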
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices of originalShape that are dropped to obtain reducedShape.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor.cast has a source that is more static than the consuming op.
Definition: TensorOps.cpp:3957
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:3961
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2624
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2625
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2612
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2613
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:379
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)