TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
32 #include "mlir/Support/LLVM.h"
33 #include "llvm/ADT/DenseSet.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallBitVector.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/LogicalResult.h"
39 #include "llvm/Support/MathExtras.h"
40 #include <algorithm>
41 #include <optional>
42 #include <vector>
43 
44 using namespace mlir;
45 using namespace mlir::tensor;
46 
47 using llvm::divideCeilSigned;
48 using llvm::divideFloorSigned;
49 using llvm::mod;
50 
51 /// Materialize a single constant operation from a given attribute value with
52 /// the desired resultant type.
53 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
54  Attribute value, Type type,
55  Location loc) {
56  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
57  return op;
58  if (complex::ConstantOp::isBuildableWith(value, type))
59  return builder.create<complex::ConstantOp>(loc, type,
60  llvm::cast<ArrayAttr>(value));
61  return nullptr;
62 }
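// Illustrative example (not from the original source): a folded dense elements
// attribute is materialized by this hook as an arith.constant, e.g.:
//
//   %cst = arith.constant dense<0.000000e+00> : tensor<4xf32>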
63 
64 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
65  int64_t dim) {
66  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
67  if (tensorType.isDynamicDim(dim))
68  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
69 
70  return builder.getIndexAttr(tensorType.getDimSize(dim));
71 }
72 
73 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
74  Location loc, Value value) {
75  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
76  SmallVector<OpFoldResult> result;
77  for (int64_t i = 0; i < tensorType.getRank(); ++i)
78  result.push_back(getMixedSize(builder, loc, value, i));
79  return result;
80 }
81 
82 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
83  OpResult opResult) {
84  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
85  assert(tensorType && "expected tensor type");
86 
87  // If the op has a destination, it implements DestinationStyleOpInterface and
88  // we can query the destination operand from that interface.
89  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
90  if (destOp)
91  return destOp.getTiedOpOperand(opResult)->get();
92 
93  // Otherwise, create a new destination tensor with the same shape.
94  OpBuilder::InsertionGuard g(b);
95  b.setInsertionPoint(opResult.getDefiningOp());
96 
97  // Compute sizes.
98  SmallVector<OpFoldResult> mixedSizes;
99  if (!tensorType.hasStaticShape()) {
100  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
101  ReifiedRankedShapedTypeDims reifiedShapes;
102  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
103  return failure();
104  mixedSizes = reifiedShapes[opResult.getResultNumber()];
105  } else {
106  // Static shape: Take static sizes directly.
107  for (int64_t sz : tensorType.getShape())
108  mixedSizes.push_back(b.getIndexAttr(sz));
109  }
110 
111  // Create empty tensor.
112  Value emptyTensor =
113  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
114  return emptyTensor;
115 }
116 
117 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
118  Operation *op,
119  SmallVector<Value> &result) {
120  for (OpResult opResult : op->getResults()) {
121  if (llvm::isa<TensorType>(opResult.getType())) {
122  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
123  if (failed(destination))
124  return failure();
125  result.push_back(*destination);
126  }
127  }
128  return success();
129 }
130 
131 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
132  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
133  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
134  return rtp1.getShape() == rtp2.getShape() &&
135  rtp1.getElementType() == rtp2.getElementType();
136  return false;
137  }
138  return tp1 == tp2; // default implementation
139 }
140 
141 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
142 /// rank-extending tensor.insert_slice op.
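///
/// Illustrative example (not from the original source): for
///
/// ```mlir
/// %0 = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
///     : tensor<2x4x3xf32> to tensor<4xf32>
/// ```
///
/// the reduced shape is [4] and the mixed sizes are [1, 4, 1], so dimensions 0
/// and 2 are reported as dropped.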
143 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
144  ArrayRef<OpFoldResult> mixedSizes) {
145  llvm::SmallBitVector droppedDims(mixedSizes.size());
146  int64_t shapePos = reducedShape.size() - 1;
147 
148  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
149  size_t idx = mixedSizes.size() - size.index() - 1;
150  // Rank-reduced dims must have a static unit dimension.
151  bool isStaticUnitSize =
152  isa<Attribute>(size.value()) &&
153  llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
154 
155  if (shapePos < 0) {
156  // There are no more dims in the reduced shape. All remaining sizes must
157  // be rank-reduced dims.
158  assert(isStaticUnitSize && "expected unit dim");
159  droppedDims.set(idx);
160  continue;
161  }
162 
163  // Dim is preserved if the size is not a static 1.
164  if (!isStaticUnitSize) {
165  --shapePos;
166  continue;
167  }
168 
169  // Dim is preserved if the reduced shape dim is also 1.
170  if (reducedShape[shapePos] == 1) {
171  --shapePos;
172  continue;
173  }
174 
175  // Otherwise: Dim is dropped.
176  droppedDims.set(idx);
177  }
178 
179  assert(shapePos < 0 && "dimension mismatch");
180  return droppedDims;
181 }
182 
183 /// Given a ranked tensor type and a range of values that defines its dynamic
184 /// dimension sizes, turn all dynamic sizes that have a constant value into
185 /// static dimension sizes.
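///
/// Illustrative example (not from the original source): for tensor<?x?xf32>
/// with dynamic sizes (%c4, %n), where %c4 is a constant 4, the returned type
/// is tensor<4x?xf32> and foldedDynamicSizes contains only %n.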
186 static RankedTensorType
187 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
188  SmallVector<Value> &foldedDynamicSizes) {
189  SmallVector<int64_t> staticShape(type.getShape());
190  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
191  "incorrect number of dynamic sizes");
192 
193  // Compute new static and dynamic sizes.
194  unsigned ctr = 0;
195  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
196  if (type.isDynamicDim(i)) {
197  Value dynamicSize = dynamicSizes[ctr++];
198  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
199  if (cst.has_value()) {
200  // Dynamic size must be non-negative.
201  if (cst.value() < 0) {
202  foldedDynamicSizes.push_back(dynamicSize);
203  continue;
204  }
205  staticShape[i] = *cst;
206  } else {
207  foldedDynamicSizes.push_back(dynamicSize);
208  }
209  }
210  }
211 
212  return RankedTensorType::get(staticShape, type.getElementType(),
213  type.getEncoding());
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // BitcastOp
218 //===----------------------------------------------------------------------===//
219 
220 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
221  if (inputs.size() != 1 || outputs.size() != 1)
222  return false;
223  Type a = inputs.front(), b = outputs.front();
224  auto aT = dyn_cast<TensorType>(a);
225  auto bT = dyn_cast<TensorType>(b);
226  if (!aT || !bT)
227  return false;
228 
229  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
230  return false;
231 
232  return succeeded(verifyCompatibleShape(aT, bT));
233 }
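// Illustrative example (not from the original source): the check above only
// requires matching element bit widths, e.g.:
//
//   %0 = tensor.bitcast %arg : tensor<4xui32> to tensor<4xi32>
//
// is compatible, whereas tensor<4xui32> to tensor<4xi64> is not.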
234 
235 namespace {
236 
237 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
238 /// operation.
239 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
240  using OpRewritePattern<BitcastOp>::OpRewritePattern;
241 
242  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
243  PatternRewriter &rewriter) const final {
244  auto tensorBitcastOperand =
245  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
246  if (!tensorBitcastOperand)
247  return failure();
248 
249  auto resultType = cast<TensorType>(tensorBitcast.getType());
250  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
251  tensorBitcastOperand.getOperand());
252  return success();
253  }
254 };
255 
256 } // namespace
257 
258 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
259  MLIRContext *context) {
260  results.add<ChainedTensorBitcast>(context);
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // CastOp
265 //===----------------------------------------------------------------------===//
266 
267 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
268  setNameFn(getResult(), "cast");
269 }
270 
271 /// Returns true if `target` is a ranked tensor type that preserves static
272 /// information available in the `source` ranked tensor type.
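///
/// Illustrative example (not from the original source):
///   preservesStaticInformation(tensor<?x8xf32>, tensor<4x8xf32>) is true,
///   preservesStaticInformation(tensor<4x8xf32>, tensor<?x8xf32>) is false.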
273 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
274  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
275  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
276 
277  // Requires RankedTensorType.
278  if (!sourceType || !targetType)
279  return false;
280 
281  // Requires same elemental type.
282  if (sourceType.getElementType() != targetType.getElementType())
283  return false;
284 
285  // Requires same rank.
286  if (sourceType.getRank() != targetType.getRank())
287  return false;
288 
289  // Requires same encoding.
290  if (sourceType.getEncoding() != targetType.getEncoding())
291  return false;
292 
293  // If cast is towards more static sizes along any dimension, don't fold.
294  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
295  if (!ShapedType::isDynamic(std::get<0>(t)) &&
296  ShapedType::isDynamic(std::get<1>(t)))
297  return false;
298  }
299 
300  return true;
301 }
302 
303 /// Determines whether tensor::CastOp casts to a more dynamic version of the
304 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
305 /// implement canonicalization patterns for ops in different dialects that may
306 /// consume the results of tensor.cast operations. Such foldable tensor.cast
307 /// operations are typically inserted when `slice` ops are canonicalized, in
308 /// order to preserve the type compatibility of their uses.
309 ///
310 /// Returns true when all conditions are met:
311 /// 1. source and result are ranked tensors with same element type and rank.
312 /// 2. the source type has more static information than the result type
313 ///
314 /// Example:
315 /// ```mlir
316 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
317 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
318 /// ```
319 ///
320 /// folds into:
321 ///
322 /// ```mlir
323 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
324 /// ```
325 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
326  if (!castOp)
327  return false;
328 
329  // Can fold if the source of cast has at least as much static information as
330  // its results.
331  return preservesStaticInformation(castOp.getType(),
332  castOp.getSource().getType());
333 }
334 
335 /// Determines whether the tensor::CastOp casts to a more static version of the
336 /// source tensor. This is useful to fold into a producing op and implement
337 /// canonicalization patterns with the `tensor.cast` op as the root, but
338 /// producer being from different dialects. Returns true when all conditions are
339 /// met:
340 /// 1. source and result are ranked tensors with same element type and rank.
341 /// 2. the result type has more static information than the source.
342 ///
343 /// Example:
344 /// ```mlir
345 /// %1 = producer ... : tensor<?x?xf32>
346 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
347 /// ```
348 ///
349 /// can be canonicalized to:
350 ///
351 /// ```mlir
352 /// %2 = producer ... : tensor<8x16xf32>
353 /// ```
354 /// Not all ops might be canonicalizable this way, but for those that can be,
355 /// this method provides a check that it is worth doing the canonicalization.
356 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
357  if (!castOp)
358  return false;
359  return preservesStaticInformation(castOp.getSource().getType(),
360  castOp.getType());
361 }
362 
363 bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
364  return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
365  if (llvm::isa<BlockArgument>(opOperand.get()))
366  return false;
367  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
368  return castOp && canFoldIntoConsumerOp(castOp);
369  });
370 }
371 
372 SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
373  DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
374  SmallVector<Value> newOperands;
375  newOperands.reserve(op->getNumOperands());
376 
377  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
378 
379  // Assumes that the result has dpsInits followed by nonDpsInits.
380  int64_t dpsInitIdx = 0;
381  for (OpOperand &opOperand : op->getOpOperands()) {
382  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
383  bool fold = canFoldIntoConsumerOp(tensorCastOp);
384  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
385  if (op.isDpsInit(&opOperand) &&
386  !llvm::isa<MemRefType>(newOperands.back().getType()))
387  newResTy[dpsInitIdx++] = newOperands.back().getType();
388  }
389  return newOperands;
390 }
391 
392 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
393 /// that can be folded.
394 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
395  bool folded = false;
396  for (OpOperand &operand : op->getOpOperands()) {
397  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
398  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
399  operand.set(castOp.getOperand());
400  folded = true;
401  }
402  }
403  return success(folded);
404 }
405 
406 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
407  if (inputs.size() != 1 || outputs.size() != 1)
408  return false;
409  Type a = inputs.front(), b = outputs.front();
410  auto aT = llvm::dyn_cast<TensorType>(a);
411  auto bT = llvm::dyn_cast<TensorType>(b);
412  if (!aT || !bT)
413  return false;
414 
415  if (aT.getElementType() != bT.getElementType())
416  return false;
417 
418  return succeeded(verifyCompatibleShape(aT, bT));
419 }
420 
421 /// Compute a TensorType that has the joined shape knowledge of the two
422 /// given TensorTypes. The element types need to match.
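///
/// Illustrative example (not from the original source): joining tensor<?x8xf32>
/// with tensor<4x?xf32> yields tensor<4x8xf32>; joining tensor<4x8xf32> with
/// tensor<5x8xf32> has no join and a null type is returned.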
423 static TensorType joinShapes(TensorType one, TensorType two) {
424  assert(one.getElementType() == two.getElementType());
425 
426  if (!one.hasRank())
427  return two;
428  if (!two.hasRank())
429  return one;
430 
431  int64_t rank = one.getRank();
432  if (rank != two.getRank())
433  return {};
434 
435  SmallVector<int64_t, 4> join;
436  join.reserve(rank);
437  for (int64_t i = 0; i < rank; ++i) {
438  if (one.isDynamicDim(i)) {
439  join.push_back(two.getDimSize(i));
440  continue;
441  }
442  if (two.isDynamicDim(i)) {
443  join.push_back(one.getDimSize(i));
444  continue;
445  }
446  if (one.getDimSize(i) != two.getDimSize(i))
447  return {};
448  join.push_back(one.getDimSize(i));
449  }
450  return RankedTensorType::get(join, one.getElementType());
451 }
452 
453 namespace {
454 
455 /// Replaces chains of two tensor.cast operations by a single tensor.cast
456 /// operation if doing so does not remove runtime constraints.
457 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
458  using OpRewritePattern<CastOp>::OpRewritePattern;
459 
460  LogicalResult matchAndRewrite(CastOp tensorCast,
461  PatternRewriter &rewriter) const final {
462  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
463 
464  if (!tensorCastOperand)
465  return failure();
466 
467  auto sourceType =
468  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
469  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
470  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
471 
472  // We can remove the intermediate cast if joining all three produces the
473  // same result as just joining the source and result shapes.
474  auto firstJoin =
475  joinShapes(joinShapes(sourceType, intermediateType), resultType);
476 
477  // The join might not exist if the cast sequence would fail at runtime.
478  if (!firstJoin)
479  return failure();
480 
481  // The newJoin always exists if the above join exists, it might just contain
482  // less information. If so, we cannot drop the intermediate cast, as doing
483  // so would remove runtime checks.
484  auto newJoin = joinShapes(sourceType, resultType);
485  if (firstJoin != newJoin)
486  return failure();
487 
488  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
489  tensorCastOperand.getOperand());
490  return success();
491  }
492 };
493 
494 /// Fold tensor.cast into tensor.extract_slice producer.
495 /// Example:
496 /// ```
497 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
498 /// tensor<128x512xf32> to tensor<?x512xf32>
499 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
500 /// ```
501 /// ->
502 /// ```
503 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
504 /// tensor<128x512xf32> to tensor<16x512xf32>
505 /// ```
506 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
507  using OpRewritePattern<CastOp>::OpRewritePattern;
508 
509  LogicalResult matchAndRewrite(CastOp tensorCast,
510  PatternRewriter &rewriter) const final {
511  auto extractOperand =
512  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
513 
514  // Cannot fold cast to unranked tensor.
515  auto rankedResultType =
516  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
517  if (!rankedResultType)
518  return failure();
519 
520  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
521  rankedResultType.getShape() ==
522  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
523  .getShape())
524  return failure();
525 
526  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
527  auto dimMask = computeRankReductionMask(
528  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
529  size_t dimIndex = 0;
530  for (size_t i = 0, e = sizes.size(); i < e; i++) {
531  if (dimMask && dimMask->count(i))
532  continue;
533  int64_t dim = rankedResultType.getShape()[dimIndex++];
534  if (ShapedType::isDynamic(dim))
535  continue;
536  sizes[i] = rewriter.getIndexAttr(dim);
537  }
538 
539  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
540  tensorCast, rankedResultType, extractOperand.getSource(),
541  extractOperand.getMixedOffsets(), sizes,
542  extractOperand.getMixedStrides());
543  return success();
544  }
545 };
546 
547 } // namespace
548 
549 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
550  MLIRContext *context) {
551  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // ConcatOp
556 //===----------------------------------------------------------------------===//
557 
558 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
559  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
560  auto tensorTypes =
561  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
562  return llvm::cast<RankedTensorType>(type);
563  }));
564  int64_t concatRank = tensorTypes[0].getRank();
565 
566  // The concatenation dim must be in the range [0, rank).
567  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
568 
569  SmallVector<int64_t> sizes(concatRank);
570  for (int64_t i = 0, e = concatRank; i < e; ++i) {
571  if (i == dim)
572  continue;
573  SaturatedInteger size;
574  for (auto tensorType : tensorTypes)
575  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
576  sizes[i] = size.asInteger();
577  }
578  auto concatSize = SaturatedInteger::wrap(0);
579  for (auto tensorType : tensorTypes)
580  concatSize =
581  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
582  sizes[dim] = concatSize.asInteger();
583  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
584 }
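// Illustrative example (not from the original source): the inference above
// gives inferResultType(1, {tensor<3x?xf32>, tensor<3x4xf32>}) ==
// tensor<3x?xf32>; sizes along the concatenated dim add up (? + 4 is dynamic)
// while all other dimensions must agree.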
585 
586 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
587  ValueRange inputs) {
588  FailureOr<RankedTensorType> resultType =
589  inferResultType(dim, inputs.getTypes());
590  assert(succeeded(resultType) && "failed to infer concatenation result type");
591  build(builder, result, *resultType, dim, inputs);
592 }
593 
594 LogicalResult ConcatOp::verify() {
595  if (getInputs().size() < 1)
596  return emitOpError("requires at least one input");
597 
598  SmallVector<RankedTensorType> inputTypes;
599  for (auto input : getInputs())
600  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
601 
602  RankedTensorType resultType = getResultType();
603  int64_t resultRank = getRank();
604  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
605  return type.getRank() != resultRank;
606  }))
607  return emitOpError("rank of concatenated inputs must match result rank");
608 
609  Type resultElementType = resultType.getElementType();
610  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
611  return type.getElementType() != resultElementType;
612  }))
613  return emitOpError("inputs and result element type must match");
614 
615  int64_t dim = getDim();
616  if (dim >= resultRank)
617  return emitOpError("concatenation dim must be less than the tensor rank");
618 
619  SmallVector<int64_t> sizes(resultRank);
620  for (int64_t i = 0, e = resultRank; i < e; ++i) {
621  if (i == dim)
622  continue;
623  SaturatedInteger size;
624  for (auto tensorType : inputTypes) {
625  FailureOr<SaturatedInteger> maybeSize =
626  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
627  if (failed(maybeSize))
628  return emitOpError("static concatenation size mismatch along ")
629  << "non-concatenated dimension " << i;
630  size = *maybeSize;
631  }
632  sizes[i] = size.asInteger();
633  }
634  auto concatSize = SaturatedInteger::wrap(0);
635  for (auto tensorType : inputTypes)
636  concatSize =
637  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
638  sizes[dim] = concatSize.asInteger();
639  auto inferredResultType =
640  RankedTensorType::get(sizes, inputTypes[0].getElementType());
641 
642  for (auto [inferredSize, actualSize] :
643  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
644  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
645  ShapedType::isDynamic(actualSize);
646  if (!hasDynamic && inferredSize != actualSize)
647  return emitOpError("result type ")
648  << resultType << " does not match inferred shape "
649  << inferredResultType << " static sizes";
650  }
651 
652  return success();
653 }
654 
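// Illustrative decomposition performed below (not from the original source;
// SSA names are hypothetical):
//
//   %r = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// is rewritten to roughly:
//
//   %empty = tensor.empty() : tensor<5x4xf32>
//   %s0 = tensor.insert_slice %a into %empty[0, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<5x4xf32>
//   %r  = tensor.insert_slice %b into %s0[2, 0] [3, 4] [1, 1]
//       : tensor<3x4xf32> into tensor<5x4xf32>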
655 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
656  size_t numInputs = getInputs().size();
657  uint64_t concatDim = getDim();
658 
659  SmallVector<SmallVector<OpFoldResult>> inputShapes;
660  inputShapes.reserve(numInputs);
661  SmallVector<OpFoldResult> concatOffsets;
662  concatOffsets.reserve(numInputs);
663  SmallVector<OpFoldResult> outputShape;
664 
665  AffineExpr addExpr =
666  builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
667  OpFoldResult zero = builder.getIndexAttr(0);
668  Location loc = getLoc();
669  for (auto [index, input] : llvm::enumerate(getInputs())) {
670  SmallVector<OpFoldResult> inputShape =
671  tensor::getMixedSizes(builder, input.getLoc(), input);
672  if (index == 0) {
673  outputShape = inputShape;
674  concatOffsets.push_back(zero);
675  } else {
676  concatOffsets.push_back(outputShape[concatDim]);
677  outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
678  builder, loc, addExpr,
679  {outputShape[concatDim], inputShape[concatDim]});
680  }
681  inputShapes.emplace_back(std::move(inputShape));
682  }
683 
684  Value replacement = builder.create<tensor::EmptyOp>(
685  loc, outputShape, getType().getElementType());
686 
687  int64_t rank = getType().getRank();
688  OpFoldResult one = builder.getIndexAttr(1);
689  SmallVector<OpFoldResult> strides(rank, one);
690  SmallVector<OpFoldResult> offsets(rank, zero);
691  for (auto [index, input] : llvm::enumerate(getInputs())) {
692  offsets[concatDim] = concatOffsets[index];
693  auto insertSlice = builder.create<tensor::InsertSliceOp>(
694  loc, input, replacement, offsets, inputShapes[index], strides);
695  replacement = insertSlice.getResult();
696  }
697  if (replacement.getType() != getType()) {
698  replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
699  }
700  return SmallVector<Value>{replacement};
701 }
702 
703 LogicalResult
704 ConcatOp::reifyResultShapes(OpBuilder &builder,
705  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
706  ValueRange inputs = getInputs();
707  int64_t dim = getDim();
708  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
709 
710  Value init = inputs[0];
711  int64_t rank = getType().getRank();
712 
713  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
714 
715  // Pre-populate the result sizes with as much static information as possible
716  // from the given result type, as well as the inferred result type, otherwise
717  // use the dim sizes from the first input.
718  for (int64_t i = 0; i < rank; ++i) {
719  if (i == dim)
720  continue;
721  if (!getType().isDynamicDim(i)) {
722  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
723  } else if (!inferredResultType.isDynamicDim(i)) {
724  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
725  builder, getLoc(),
726  builder.getIndexAttr(inferredResultType.getDimSize(i)));
727  } else {
728  reifiedReturnShapes[0][i] =
729  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
730  }
731  }
732 
733  if (getType().isDynamicDim(dim)) {
734  // Take the sum of the input sizes along the concatenated dim.
735  AffineExpr sum = builder.getAffineDimExpr(0);
736  SmallVector<OpFoldResult> sizes = {
737  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
738  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
739  sum = sum + builder.getAffineDimExpr(idx + 1);
740  sizes.push_back(
741  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
742  }
743  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
744  builder, getLoc(),
745  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
746  } else {
747  // If the result shape is static along the concatenated dim, use the static
748  // shape.
749  reifiedReturnShapes[0][dim] =
750  builder.getIndexAttr(getType().getDimSize(dim));
751  }
752  return success();
753 }
754 
755 void ConcatOp::getAsmResultNames(
756  function_ref<void(Value, StringRef)> setNameFn) {
757  setNameFn(getResult(), "concat");
758 }
759 
760 OpFoldResult ConcatOp::fold(FoldAdaptor) {
761  ValueRange inputs = getInputs();
762  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
763  return inputs[0];
764  return {};
765 }
766 
767 namespace {
768 /// Fold a concat op with a single input to a cast.
769 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
770  using OpRewritePattern<ConcatOp>::OpRewritePattern;
771 
772  LogicalResult matchAndRewrite(ConcatOp concatOp,
773  PatternRewriter &rewriter) const override {
774  if (concatOp.getInputs().size() != 1)
775  return failure();
776  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
777  concatOp.getInputs()[0]);
778  return success();
779  }
780 };
781 
782 /// Propagate static shapes into the operands of a `tensor.concat`.
783 ///
784 /// `tensor.concat` requires every operand to match on all dimensions except the
785 /// concatenation dimension. If one operand is already static in those
786 /// dimensions, the other operands may safely be refined to that same static
787 /// shape.
788 ///
789 /// Example:
790 ///
791 /// ```mlir
792 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
793 /// tensor<?x12xi32>
794 /// ```
795 /// ->
796 /// ```mlir
797 /// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
798 /// %2 = tensor.concat dim(0) %0, %cast :
799 /// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
800 /// ```
801 struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
802  using OpRewritePattern<ConcatOp>::OpRewritePattern;
803 
804  LogicalResult matchAndRewrite(ConcatOp concatOp,
805  PatternRewriter &rewriter) const override {
806  int64_t dim = concatOp.getDim();
807  RankedTensorType inferredResultType =
808  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
809 
810  // Find operands for which a more static shape can be inferred.
811  LogicalResult matched = failure();
812  // Inferred operand shapes are identical in every dimension except the
813  // concatenation dimension.
814  SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
815  for (auto [operandIdx, operandType] :
816  llvm::enumerate(concatOp->getOperandTypes())) {
817  // Compute inferred type for operand.
818  inferredOperandShape[dim] =
819  cast<RankedTensorType>(operandType).getDimSize(dim);
820  auto inferredOperandType = RankedTensorType::get(
821  inferredOperandShape, inferredResultType.getElementType());
822 
823  // Check if inferred type is more static.
824  if (!preservesStaticInformation(inferredOperandType, operandType)) {
825  matched = success();
826 
827  // Use refined operand type and create cast from original operand.
828  auto castOp =
829  rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
830  concatOp.getOperand(operandIdx));
831  rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
832  concatOp->setOperand(operandIdx, castOp->getResult(0));
833  });
834  }
835  }
836 
837  return matched;
838  }
839 };
840 
841 /// Ensure `tensor.concat`'s result type is at least as static as can be inferred
842 /// from its operand types.
843 ///
844 /// Example:
845 /// ```mlir
846 /// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
847 /// tensor<?x?xi32>
848 /// ```
849 /// ->
850 /// ```mlir
851 /// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
852 /// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
853 /// tensor<?x?xi32>
854 /// ```
855 struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
856  using OpRewritePattern<ConcatOp>::OpRewritePattern;
857 
858  LogicalResult matchAndRewrite(ConcatOp concatOp,
859  PatternRewriter &rewriter) const override {
860  int64_t dim = concatOp.getDim();
861  RankedTensorType inferredResultType =
862  ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
863 
864  // The result type should be at least as static as inferred result type.
865  if (preservesStaticInformation(inferredResultType,
866  concatOp.getResultType())) {
867  return failure();
868  }
869 
870  auto newConcatOp = rewriter.create<ConcatOp>(
871  concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
872  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
873  newConcatOp);
874 
875  return success();
876  }
877 };
878 } // namespace
879 
880 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
881  MLIRContext *context) {
882  results
883  .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
884  context);
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // DimOp
889 //===----------------------------------------------------------------------===//
890 
891 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
892  setNameFn(getResult(), "dim");
893 }
894 
895 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
896  int64_t index) {
897  auto loc = result.location;
898  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
899  build(builder, result, source, indexValue);
900 }
901 
902 std::optional<int64_t> DimOp::getConstantIndex() {
903  return getConstantIntValue(getIndex());
904 }
905 
906 Speculation::Speculatability DimOp::getSpeculatability() {
907  auto constantIndex = getConstantIndex();
908  if (!constantIndex)
909  return Speculation::NotSpeculatable;
910 
911  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
912  if (!rankedSourceType)
913  return Speculation::NotSpeculatable;
914 
915  if (rankedSourceType.getRank() <= constantIndex)
916  return Speculation::NotSpeculatable;
917 
918  return Speculation::Speculatable;
919 }
920 
921 void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
922  SetIntLatticeFn setResultRange) {
923  setResultRange(getResult(),
924  intrange::inferShapedDimOpInterface(*this, argRanges[1]));
925 }
926 
927 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
928  // All forms of folding require a known index.
929  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
930  if (!index)
931  return {};
932 
933  // Folding for unranked types (UnrankedTensorType) is not supported.
934  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
935  if (!tensorType)
936  return {};
937 
938  // Out of bound indices produce undefined behavior but are still valid IR.
939  // Don't choke on them.
940  int64_t indexVal = index.getInt();
941  if (indexVal < 0 || indexVal >= tensorType.getRank())
942  return {};
943 
944  // Fold if the shape extent along the given index is known.
945  if (!tensorType.isDynamicDim(index.getInt())) {
946  Builder builder(getContext());
947  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
948  }
949 
950  Operation *definingOp = getSource().getDefiningOp();
951 
952  // Fold dim to the operand of tensor.generate.
953  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
954  auto resultType =
955  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
956  // The case where the type encodes the size of the dimension is handled
957  // above.
958  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
959 
960  // Find the operand of the fromElements that corresponds to this index.
961  auto dynExtents = fromElements.getDynamicExtents().begin();
962  for (auto dim : resultType.getShape().take_front(index.getInt()))
963  if (ShapedType::isDynamic(dim))
964  dynExtents++;
965 
966  return Value{*dynExtents};
967  }
968 
969  // The size at the given index is now known to be a dynamic size.
970  unsigned unsignedIndex = index.getValue().getZExtValue();
971 
972  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
973  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
974  // `resolve-shaped-type-result-dims` pass.
975  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
976  sliceOp.isDynamicSize(unsignedIndex)) {
977  return {sliceOp.getDynamicSize(unsignedIndex)};
978  }
979  }
980 
981  // dim(cast) -> dim
982  if (succeeded(foldTensorCast(*this)))
983  return getResult();
984 
985  return {};
986 }
987 
988 namespace {
989 /// Fold dim of a cast into the dim of the source of the tensor cast.
990 struct DimOfCastOp : public OpRewritePattern<DimOp> {
991  using OpRewritePattern<DimOp>::OpRewritePattern;
992 
993  LogicalResult matchAndRewrite(DimOp dimOp,
994  PatternRewriter &rewriter) const override {
995  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
996  if (!castOp)
997  return failure();
998  Value newSource = castOp.getOperand();
999  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
1000  return success();
1001  }
1002 };
1003 
1004 /// Fold dim of a destination passing style op into the dim of the corresponding
1005 /// init.
1006 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1007  using OpRewritePattern<DimOp>::OpRewritePattern;
1008 
1009  LogicalResult matchAndRewrite(DimOp dimOp,
1010  PatternRewriter &rewriter) const override {
1011  auto source = dimOp.getSource();
1012  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1013  if (!destOp)
1014  return failure();
1015 
1016  auto resultIndex = cast<OpResult>(source).getResultNumber();
1017  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1018 
1019  rewriter.modifyOpInPlace(
1020  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1021  return success();
1022  }
1023 };
1024 
1025 /// Fold dim of a tensor reshape operation into an extract from the reshape's
1026 /// shape operand.
1027 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1028  using OpRewritePattern<DimOp>::OpRewritePattern;
1029 
1030  LogicalResult matchAndRewrite(DimOp dim,
1031  PatternRewriter &rewriter) const override {
1032  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1033 
1034  if (!reshape)
1035  return failure();
1036 
1037  // Since tensors are immutable, we don't need to worry about where to place
1038  // the extract call.
1039  rewriter.setInsertionPointAfter(dim);
1040  Location loc = dim.getLoc();
1041  Value extract =
1042  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
1043  if (extract.getType() != dim.getType())
1044  extract =
1045  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
1046  rewriter.replaceOp(dim, extract);
1047  return success();
1048  }
1049 };
1050 } // namespace
1051 
1052 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1053  MLIRContext *context) {
1054  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1055 }
1056 
1057 //===----------------------------------------------------------------------===//
1058 // EmptyOp
1059 //===----------------------------------------------------------------------===//
1060 
1061 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1062  ArrayRef<int64_t> staticShape, Type elementType,
1063  Attribute encoding) {
1064  assert(none_of(staticShape, ShapedType::isDynamic) &&
1065  "expected only static sizes");
1066  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
1067 }
1068 
1069 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1070  ArrayRef<int64_t> staticShape, Type elementType,
1071  ValueRange dynamicSizes, Attribute encoding) {
1072  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1073  build(builder, result, tensorType, dynamicSizes);
1074 }
1075 
1076 void EmptyOp::build(OpBuilder &builder, OperationState &result,
1077  ArrayRef<OpFoldResult> sizes, Type elementType,
1078  Attribute encoding) {
1079  SmallVector<int64_t> staticShape;
1080  SmallVector<Value> dynamicSizes;
1081  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
1082  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
1083 }
1084 
1085 LogicalResult EmptyOp::verify() {
1086  if (getType().getNumDynamicDims() != getDynamicSizes().size())
1087  return emitOpError("incorrect number of dynamic sizes, has ")
1088  << getDynamicSizes().size() << ", expected "
1089  << getType().getNumDynamicDims();
1090  return success();
1091 }
1092 
1093 LogicalResult
1094 EmptyOp::reifyResultShapes(OpBuilder &builder,
1095  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1096  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1097  unsigned ctr = 0;
1098  for (int64_t i = 0; i < getType().getRank(); ++i) {
1099  if (getType().isDynamicDim(i)) {
1100  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1101  } else {
1102  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
1103  }
1104  }
1105  return success();
1106 }
1107 
1108 Value EmptyOp::getDynamicSize(unsigned idx) {
1109  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1110  unsigned ctr = 0;
1111  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1112  if (getType().isDynamicDim(i))
1113  ++ctr;
1114  return getDynamicSizes()[ctr];
1115 }
1116 
1117 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1118  SmallVector<OpFoldResult> result;
1119  unsigned ctr = 0;
1120  OpBuilder b(getContext());
1121  for (int64_t i = 0; i < getType().getRank(); ++i) {
1122  if (getType().isDynamicDim(i)) {
1123  result.push_back(getDynamicSizes()[ctr++]);
1124  } else {
1125  result.push_back(b.getIndexAttr(getType().getShape()[i]));
1126  }
1127  }
1128  return result;
1129 }
1130 
1131 namespace {
1132 /// Change the type of the result of a `tensor.empty` by making the result
1133 /// type statically sized along dimensions that in the original operation were
1134 /// defined as dynamic, but the size was defined using a `constant` op. For
1135 /// example
1136 ///
1137 /// %c5 = arith.constant 5: index
1138 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1139 ///
1140 /// to
1141 ///
1142 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1143 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1144  using OpRewritePattern<EmptyOp>::OpRewritePattern;
1145 
1146  LogicalResult matchAndRewrite(EmptyOp op,
1147  PatternRewriter &rewriter) const override {
1148  SmallVector<Value> foldedDynamicSizes;
1149  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1150  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1151 
1152  // Stop here if no dynamic size was promoted to static.
1153  if (foldedTensorType == op.getType())
1154  return failure();
1155 
1156  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
1157  foldedDynamicSizes);
1158  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1159  return success();
1160  }
1161 };
1162 
1163 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1164  using OpRewritePattern<DimOp>::OpRewritePattern;
1165 
1166  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1167  PatternRewriter &rewriter) const override {
1168  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1169  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1170  if (!emptyTensorOp || !maybeConstantIndex)
1171  return failure();
1172  auto emptyTensorType = emptyTensorOp.getType();
1173  if (*maybeConstantIndex < 0 ||
1174  *maybeConstantIndex >= emptyTensorType.getRank() ||
1175  !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1176  return failure();
1177  rewriter.replaceOp(dimOp,
1178  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1179  return success();
1180  }
1181 };
1182 
1183 /// Canonicalize
1184 ///
1185 /// ```mlir
1186 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1187 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1188 /// ```
1189 ///
1190 /// into
1191 ///
1192 /// ```mlir
1193 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1194 /// ```
1195 ///
1196 /// This assumes the input program is correct in terms of its shape. So it is
1197 /// safe to assume that `%d0` is in fact 4.
1198 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1199  using OpRewritePattern<CastOp>::OpRewritePattern;
1200 
1201  LogicalResult matchAndRewrite(CastOp castOp,
1202  PatternRewriter &rewriter) const override {
1203  if (!canFoldIntoProducerOp(castOp))
1204  return failure();
1205  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1206  if (!producer)
1207  return failure();
1208 
1209  auto resultType =
1210  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1211  ArrayRef<int64_t> resultShape = resultType.getShape();
1212  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1213  SmallVector<OpFoldResult> newMixedSizes;
1214  newMixedSizes.reserve(currMixedSizes.size());
1215  assert(resultShape.size() == currMixedSizes.size() &&
1216  "mismatch in result shape and sizes of empty op");
1217  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1218  int64_t newDim = std::get<0>(it);
1219  OpFoldResult currDim = std::get<1>(it);
1220  // Case 1: The empty tensor dim is static. Check that the tensor cast
1221  // result dim matches.
1222  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1223  if (ShapedType::isDynamic(newDim) ||
1224  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1225  // Something is off, the cast result shape cannot be more dynamic
1226  // than the empty tensor result shape (enforced by
1227  // `canFoldIntoProducerOp`). Abort for now.
1228  return rewriter.notifyMatchFailure(
1229  producer, "mismatch in static value of shape of empty tensor "
1230  "result and cast result");
1231  }
1232  newMixedSizes.push_back(attr);
1233  continue;
1234  }
1235 
1236  // Case 2 : The tensor cast shape is static, but empty tensor result
1237  // shape is dynamic.
1238  if (!ShapedType::isDynamic(newDim)) {
1239  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1240  continue;
1241  }
1242 
1243  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1244  // shape is dynamic. Use the dynamic value from the empty tensor op.
1245  newMixedSizes.push_back(currDim);
1246  }
1247 
1248  // TODO: Do not drop tensor encoding.
1249  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1250  resultType.getElementType());
1251  return success();
1252  }
1253 };
1254 
1255 } // namespace
1256 
1257 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1258  MLIRContext *context) {
1259  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1260  ReplaceEmptyTensorStaticShapeDims>(context);
1261 }
1262 
1263 //===----------------------------------------------------------------------===//
1264 // ExtractOp
1265 //===----------------------------------------------------------------------===//
1266 
1267 namespace {
1268 
1269 /// Canonicalizes the pattern of the form
1270 ///
1271 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1272 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1273 ///
1274 /// to
1275 ///
1276 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1277 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1278  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1279 
1280  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1281  PatternRewriter &rewriter) const final {
1282  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1283  if (!tensorCast)
1284  return failure();
1285  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1286  return failure();
1287  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1288  extract, tensorCast.getSource(), extract.getIndices());
1289  return success();
1290  }
1291 };
1292 
1293 /// Canonicalizes the pattern of the form
1294 ///
1295 /// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1296 /// tensor<12xf64>
1297 /// %extracted_element = tensor.extract %val[%c10] :
1298 /// tensor<12xf64>
1299 ///
1300 /// to
1301 ///
1302 /// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1303 struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1304  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1305 
1306  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1307  PatternRewriter &rewriter) const final {
1308  auto collapseOp =
1309  extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1310  if (!collapseOp)
1311  return failure();
1312  if (!collapseOp.getSrcType().hasStaticShape())
1313  return failure();
1314 
1315  auto sourceSizes = collapseOp.getSrcType().getShape();
1316 
1317  SmallVector<Value> indices(extractOp.getIndices().begin(),
1318  extractOp.getIndices().end());
1319  SmallVector<Value> sourceIndices;
1320  for (auto [index, group] :
1321  llvm::zip(indices, collapseOp.getReassociationIndices())) {
1322  assert(!group.empty() && "association indices groups cannot be empty");
1323  auto groupSize = group.size();
1324 
1325  if (groupSize == 1) {
1326  sourceIndices.push_back(index);
1327  continue;
1328  }
1329 
1330  SmallVector<int64_t> basis =
1331  llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1332  auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
1333  extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1334  llvm::append_range(sourceIndices, delinearize.getResults());
1335  }
1336  if (collapseOp.getReassociationIndices().empty()) {
1337  auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1338  int64_t srcRank =
1339  cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1340  OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1341  rewriter, extractOp.getLoc(), zeroAffineMap,
1342  ArrayRef<OpFoldResult>{});
1343  for (int64_t i = 0; i < srcRank; i++) {
1344  sourceIndices.push_back(
1345  getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1346  }
1347  }
1348 
1349  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1350  extractOp, collapseOp.getSrc(), sourceIndices);
1351  return success();
1352  }
1353 };
1354 
1355 } // namespace
1356 
1357 void ExtractOp::getAsmResultNames(
1358  function_ref<void(Value, StringRef)> setNameFn) {
1359  setNameFn(getResult(), "extracted");
1360 }
1361 
1362 LogicalResult ExtractOp::verify() {
1363  // Verify the # indices match if we have a ranked type.
1364  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1365  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1366  return emitOpError("incorrect number of indices for extract_element");
1367  return success();
1368 }
1369 
1370 /// If we have an ExtractOp consuming an InsertOp with the same
1371 /// indices, we can return the InsertOp's scalar directly.
1372 // TODO: This only checks the immediate producer; extend to go up the
1373 // insert/extract chain if the slices are disjoint.
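//
// Illustrative example (not from the original source):
//   %0 = tensor.insert %f into %t[%i, %j] : tensor<4x4xf32>
//   %1 = tensor.extract %0[%i, %j] : tensor<4x4xf32>
// folds %1 to %f.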
1374 static Value foldExtractAfterInsert(ExtractOp extractOp) {
1375  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1376 
1377  auto isSame = [](Value a, Value b) {
1378  return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1379  };
1380  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1381  llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1382  return insertOp.getScalar();
1383 
1384  return {};
1385 }
1386 
1387 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1388  if (Attribute tensor = adaptor.getTensor()) {
1389  // If this is a splat elements attribute, simply return the value.
1390  // All of the elements of a splat attribute are the same.
1391  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1392  return splatTensor.getSplatValue<Attribute>();
1393 
1394  // If this is a dense resource elements attribute, return.
1395  if (isa<DenseResourceElementsAttr>(tensor))
1396  return {};
1397  }
1398 
1399  // Collect the constant indices into the tensor.
1400  SmallVector<uint64_t, 8> indices;
1401  for (Attribute indice : adaptor.getIndices()) {
1402  if (!indice || !llvm::isa<IntegerAttr>(indice))
1403  return {};
1404  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1405  }
1406 
1407  // Fold extract(from_elements(...)).
1408  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1409  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1410  auto rank = tensorType.getRank();
1411  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1412  "rank mismatch");
1413  int flatIndex = 0;
1414  int stride = 1;
1415  for (int i = rank - 1; i >= 0; --i) {
1416  flatIndex += indices[i] * stride;
1417  stride *= tensorType.getDimSize(i);
1418  }
1419  // Prevent out of bounds accesses. This can happen in invalid code that
1420  // will never execute.
1421  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1422  flatIndex < 0)
1423  return {};
1424  return fromElementsOp.getElements()[flatIndex];
1425  }
1426 
1427  // If this is an elements attribute, query the value at the given indices.
1428  if (Attribute tensor = adaptor.getTensor()) {
1429  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1430  if (elementsAttr && elementsAttr.isValidIndex(indices))
1431  return elementsAttr.getValues<Attribute>()[indices];
1432  }
1433 
1434  if (Value result = foldExtractAfterInsert(*this))
1435  return result;
1436 
1437  return {};
1438 }
1439 
1440 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1441  MLIRContext *context) {
1442  results.add<ExtractFromTensorCast>(context);
1443 }
1444 
1445 void mlir::tensor::populateFoldCollapseExtractPatterns(
1446  RewritePatternSet &patterns) {
1447  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
1448 }
1449 
1450 //===----------------------------------------------------------------------===//
1451 // FromElementsOp
1452 //===----------------------------------------------------------------------===//
1453 
1454 void FromElementsOp::getAsmResultNames(
1455  function_ref<void(Value, StringRef)> setNameFn) {
1456  setNameFn(getResult(), "from_elements");
1457 }
1458 
1459 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1460  ValueRange elements) {
1461  assert(!elements.empty() && "expected at least one element");
1462  Type resultType = RankedTensorType::get(
1463  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1464  build(builder, result, resultType, elements);
1465 }
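// Illustrative example (not from the original source): building from three
// index values infers a 1-D result type, i.e.
//
//   %t = tensor.from_elements %a, %b, %c : tensor<3xindex>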
1466 
1467 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1468  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1469  return DenseElementsAttr::get(getType(), adaptor.getElements());
1470  return {};
1471 }
1472 
1473 namespace {
1474 
1475 // Pushes the index_casts that occur before extractions to after the extract.
1476 // This minimizes type conversion in some cases and enables the extract
1477 // canonicalizer. This changes:
1478 //
1479 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1480 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1481 //
1482 // to the following:
1483 //
1484 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1485 // %cast = arith.index_cast %extract : i32 to index
1486 //
1487 // so that the index_cast operates on the extracted scalar instead of a tensor.
1488 //
1489 // Consider expanding this to a template and handle all tensor cast
1490 // operations.
1491 struct ExtractElementFromIndexCast
1492  : public OpRewritePattern<tensor::ExtractOp> {
1493  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1494 
1495  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1496  PatternRewriter &rewriter) const final {
1497  Location loc = extract.getLoc();
1498  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1499  if (!indexCast)
1500  return failure();
1501 
1502  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1503 
1504  auto newExtract = rewriter.create<tensor::ExtractOp>(
1505  loc, elementTy, indexCast.getIn(), extract.getIndices());
1506 
1507  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1508  newExtract);
1509 
1510  return success();
1511  }
1512 };
1513 
1514 } // namespace
1515 
1516 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1517  MLIRContext *context) {
1518  results.add<ExtractElementFromIndexCast>(context);
1519 }
1520 
1521 //===----------------------------------------------------------------------===//
1522 // GatherOp
1523 //===----------------------------------------------------------------------===//
1524 
1525 void GatherOp::getAsmResultNames(
1526  function_ref<void(Value, StringRef)> setNameFn) {
1527  setNameFn(getResult(), "gather");
1528 }
1529 
1530 /// Return the inferred result type for a gatherOp where:
1531 /// - sourceType is the type of the source tensor gathered from
1532 /// - indicesType is the type of the indices used to gather
1533 /// - gatherDims are the dims along which the gather occurs.
1534 /// Return a full rank or ranked-reduced variant of the type depending on
1535 /// the value of rankReduced.
1536 ///
1537 /// The leading dimensions of the index tensor give the result tensor its
1538 /// leading dimensions.
1539 /// The trailing dimensions of the result tensor are obtained from the source
1540 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1541 /// rankedReduced is false), or skipping them (otherwise).
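///
/// Illustrative example (not from the original source): with a source of type
/// tensor<4x5x6xf32>, indices of type tensor<2x3x1xindex>, and gather_dims =
/// [1], the inferred result type is tensor<2x3x4x1x6xf32>, or
/// tensor<2x3x4x6xf32> when rankReduced is true.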
1542 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1543  RankedTensorType indicesType,
1544  ArrayRef<int64_t> gatherDims,
1545  bool rankReduced) {
1546  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1547  resultShape.reserve(resultShape.size() + sourceType.getRank());
1548  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1549  if (llvm::binary_search(gatherDims, idx)) {
1550  if (!rankReduced)
1551  resultShape.push_back(1);
1552  continue;
1553  }
1554  resultShape.push_back(sourceType.getDimSize(idx));
1555  }
1556  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1557 }
1558 
1559 static LogicalResult
1560 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1561  ArrayRef<int64_t> indices, int64_t rank,
1562  StringRef gatherOrScatter, StringRef sourceOrDest) {
1563  if (dims.empty())
1564  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1565 
1566  int64_t numGatherDims = dims.size();
1567  if (numGatherDims > rank)
1568  return op->emitOpError(gatherOrScatter)
1569  << "_dims overflow " << sourceOrDest << " rank";
1570  if (indices.empty() || indices.back() != numGatherDims)
1571  return op->emitOpError(gatherOrScatter)
1572  << "_dims length must match the size of last dimension of indices";
1573  for (int64_t val : dims) {
1574  if (val < 0)
1575  return op->emitOpError(gatherOrScatter)
1576  << "_dims value must be non-negative";
1577  if (val >= rank)
1578  return op->emitOpError(gatherOrScatter)
1579  << "_dims value must be smaller than " << sourceOrDest << " rank";
1580  }
1581  for (int64_t i = 1; i < numGatherDims; ++i) {
1582  if (dims[i - 1] >= dims[i])
1583  return op->emitOpError(gatherOrScatter)
1584  << "_dims values must be strictly increasing";
1585  }
1586  return success();
1587 }
1588 
1589 LogicalResult GatherOp::verify() {
1590  int64_t sourceRank = getSourceType().getRank();
1591  ArrayRef<int64_t> gatherDims = getGatherDims();
1592  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1593  getIndicesType().getShape(), sourceRank,
1594  "gather", "source")))
1595  return failure();
1596 
1597  RankedTensorType expectedResultType = GatherOp::inferResultType(
1598  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1599  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1600  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1601  if (getResultType() != expectedResultType &&
1602  getResultType() != expectedRankReducedResultType) {
1603  return emitOpError("result type mismatch: expected ")
1606  << expectedResultType << " or its rank-reduced variant "
1607  << expectedRankReducedResultType << " (got: " << getResultType()
1608  << ")";
1609  }
1610 
1611  return success();
1612 }
1613 
1614 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1615  if (OpFoldResult reshapedSource = reshapeConstantSource(
1616  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1617  getResult().getType()))
1618  return reshapedSource;
1619  return {};
1620 }
1621 
1622 //===----------------------------------------------------------------------===//
1623 // InsertOp
1624 //===----------------------------------------------------------------------===//
1625 
1626 void InsertOp::getAsmResultNames(
1627  function_ref<void(Value, StringRef)> setNameFn) {
1628  setNameFn(getResult(), "inserted");
1629 }
1630 
1631 LogicalResult InsertOp::verify() {
1632  // Verify the # indices match if we have a ranked type.
1633  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1634  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1635  return emitOpError("incorrect number of indices");
1636  return success();
1637 }
1638 
1639 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1640  Attribute scalar = adaptor.getScalar();
1641  Attribute dest = adaptor.getDest();
1642  if (scalar && dest)
1643  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1644  if (scalar == splatDest.getSplatValue<Attribute>())
1645  return dest;
1646  return {};
1647 }
1648 
1649 //===----------------------------------------------------------------------===//
1650 // GenerateOp
1651 //===----------------------------------------------------------------------===//
1652 
1653 void GenerateOp::getAsmResultNames(
1654  function_ref<void(Value, StringRef)> setNameFn) {
1655  setNameFn(getResult(), "generated");
1656 }
1657 
1658 LogicalResult GenerateOp::reifyResultShapes(
1659  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1660  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1661  int idx = 0;
1662  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1663  if (getType().isDynamicDim(dim)) {
1664  reifiedReturnShapes[0][dim] = getOperand(idx++);
1665  } else {
1666  reifiedReturnShapes[0][dim] =
1667  builder.getIndexAttr(getType().getDimSize(dim));
1668  }
1669  }
1670  return success();
1671 }
1672 
1673 LogicalResult GenerateOp::verify() {
1674  // Ensure that the tensor type has as many dynamic dimensions as are
1675  // specified by the operands.
1676  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1677  if (getNumOperands() != resultType.getNumDynamicDims())
1678  return emitError("must have as many index operands as dynamic extents "
1679  "in the result type");
1680  return success();
1681 }
1682 
1683 LogicalResult GenerateOp::verifyRegions() {
1684  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1685  // Ensure that region arguments span the index space.
1686  if (!llvm::all_of(getBody().getArgumentTypes(),
1687  [](Type ty) { return ty.isIndex(); }))
1688  return emitError("all body arguments must be index");
1689  if (getBody().getNumArguments() != resultTy.getRank())
1690  return emitError("must have one body argument per input dimension");
1691 
1692  // Ensure that the region yields an element of the right type.
1693  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1694 
1695  if (yieldOp.getValue().getType() != resultTy.getElementType())
1696  return emitOpError(
1697  "body must be terminated with a `yield` operation of the tensor "
1698  "element type");
1699 
1700  return success();
1701 }
1702 
1703 void GenerateOp::build(
1704  OpBuilder &b, OperationState &result, Type resultTy,
1705  ValueRange dynamicExtents,
1706  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1707  build(b, result, resultTy, dynamicExtents);
1708 
1709  // Build and populate body.
1710  OpBuilder::InsertionGuard guard(b);
1711  Region *bodyRegion = result.regions.front().get();
1712  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1713  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1714  SmallVector<Location, 2> argumentLocs(rank, result.location);
1715  Block *bodyBlock =
1716  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1717  bodyBuilder(b, result.location, bodyBlock->getArguments());
1718 }
1719 
1720 namespace {
1721 
1722 /// Canonicalizes tensor.generate operations whose dynamic extent operands are
1723 /// constants by folding those extents into the static result type. A
1724 /// tensor.cast back to the original result type is inserted so that the
1725 /// resulting IR is still well-typed.
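/// A sketch of the rewrite on hypothetical IR, assuming %c16 is an
/// arith.constant of value 16:
/// ```mlir
/// %0 = tensor.generate %c16 { ... } : tensor<?xindex>
/// ```
/// becomes
/// ```mlir
/// %static = tensor.generate { ... } : tensor<16xindex>
/// %0 = tensor.cast %static : tensor<16xindex> to tensor<?xindex>
/// ```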
1726 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1727  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1728 
1729  LogicalResult matchAndRewrite(GenerateOp generateOp,
1730  PatternRewriter &rewriter) const final {
1731  SmallVector<Value> foldedDynamicSizes;
1732  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1733  generateOp.getType(), generateOp.getDynamicExtents(),
1734  foldedDynamicSizes);
1735 
1736  // Stop here if no dynamic size was promoted to static.
1737  if (foldedTensorType == generateOp.getType())
1738  return failure();
1739 
1740  auto loc = generateOp.getLoc();
1741  auto newOp =
1742  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1743  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1744  newOp.getBody().begin());
1745  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1746  generateOp.getType(), newOp);
1747  return success();
1748  }
1749 };
1750 
1751 /// Canonicalizes the pattern of the form
1752 ///
1753 /// %tensor = tensor.generate %x {
1754 /// ^bb0(%arg0: index):
1755 /// <computation>
1756 /// yield %1 : index
1757 /// } : tensor<?xindex>
1758 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1759 ///
1760 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1761 /// tensor.generate operation has no side-effects.
1762 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1763  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1764 
1765  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1766  PatternRewriter &rewriter) const final {
1767  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1768  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1769  return failure();
1770 
1771  IRMapping mapping;
1772  Block *body = &tensorFromElements.getBody().front();
1773  mapping.map(body->getArguments(), extract.getIndices());
1774  for (auto &op : body->without_terminator())
1775  rewriter.clone(op, mapping);
1776 
1777  auto yield = cast<YieldOp>(body->getTerminator());
1778 
1779  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1780  return success();
1781  }
1782 };
1783 
1784 } // namespace
1785 
1786 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1787  MLIRContext *context) {
1788  // TODO: Move extract pattern to tensor::ExtractOp.
1789  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1790 }
1791 
1792 //===----------------------------------------------------------------------===//
1793 // RankOp
1794 //===----------------------------------------------------------------------===//
1795 
1796 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1797  setNameFn(getResult(), "rank");
1798 }
1799 
1800 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1801  // Constant fold rank when the rank of the operand is known.
1802  auto type = getOperand().getType();
1803  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1804  if (shapedType && shapedType.hasRank())
1805  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1806  return IntegerAttr();
1807 }
1808 
1809 //===----------------------------------------------------------------------===//
1810 // ReshapeOp
1811 //===----------------------------------------------------------------------===//
1812 
1813 void ReshapeOp::getAsmResultNames(
1814  function_ref<void(Value, StringRef)> setNameFn) {
1815  setNameFn(getResult(), "reshape");
1816 }
1817 
1818 static int64_t getNumElements(ShapedType type) {
1819  int64_t numElements = 1;
1820  for (auto dim : type.getShape())
1821  numElements *= dim;
1822  return numElements;
1823 }
1824 
1825 LogicalResult ReshapeOp::verify() {
1826  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1827  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1828 
1829  if (operandType.getElementType() != resultType.getElementType())
1830  return emitOpError("element types of source and destination tensor "
1831  "types should be the same");
1832 
1833  int64_t shapeSize =
1834  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1835  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1836  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1837 
1838  if (resultRankedType) {
1839  if (operandRankedType && resultRankedType.hasStaticShape() &&
1840  operandRankedType.hasStaticShape()) {
1841  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1842  return emitOpError("source and destination tensor should have the "
1843  "same number of elements");
1844  }
1845  if (ShapedType::isDynamic(shapeSize))
1846  return emitOpError("cannot use shape operand with dynamic length to "
1847  "reshape to statically-ranked tensor type");
1848  if (shapeSize != resultRankedType.getRank())
1849  return emitOpError(
1850  "length of shape operand differs from the result's tensor rank");
1851  }
1852  return success();
1853 }
1854 
1855 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1856  if (OpFoldResult reshapedSource = reshapeConstantSource(
1857  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1858  getResult().getType()))
1859  return reshapedSource;
1860 
1861  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1862  // producer's input instead as the original tensor to reshape. This can
1863  // render that producer dead.
1864  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1865  getSourceMutable().assign(reshapeOpProducer.getSource());
1866  return getResult();
1867  }
1868 
1869  auto source = getSource();
1870  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1871  auto resultTy = dyn_cast<RankedTensorType>(getType());
1872  if (!sourceTy || !resultTy || sourceTy != resultTy)
1873  return {};
1874 
1875  // If the source and result are both 1D tensors and have the same type, the
1876  // reshape has no effect, even if the tensor is dynamically shaped.
1877  if (sourceTy.getRank() == 1)
1878  return source;
1879 
1880  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1881  auto elements = fromElements.getElements();
1882  bool dynamicNoop =
1883  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1884  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1885  auto element = elements[id];
1886 
1887  if (auto cst = getConstantIntValue(element)) {
1888  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1889  continue;
1890  }
1891 
1892  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1893  dynamicNoop &= dimOp.getSource() == source;
1894 
1895  auto cst = getConstantIntValue(dimOp.getIndex());
1896  dynamicNoop &=
1897  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1898  continue;
1899  }
1900 
1901  dynamicNoop = false;
1902  break;
1903  }
1904 
1905  if (dynamicNoop)
1906  return source;
1907  }
1908 
1909  return {};
1910 }
1911 
1912 //===----------------------------------------------------------------------===//
1913 // Reassociative reshape ops
1914 //===----------------------------------------------------------------------===//
1915 
1916 void CollapseShapeOp::getAsmResultNames(
1917  function_ref<void(Value, StringRef)> setNameFn) {
1918  setNameFn(getResult(), "collapsed");
1919 }
1920 
1921 void ExpandShapeOp::getAsmResultNames(
1922  function_ref<void(Value, StringRef)> setNameFn) {
1923  setNameFn(getResult(), "expanded");
1924 }
1925 
1926 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1927  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1928  "invalid resultDim");
1929  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1930  if (llvm::is_contained(it.value(), resultDim))
1931  return it.index();
1932  llvm_unreachable("could not find reassociation group");
1933 }
1934 
1935 FailureOr<SmallVector<OpFoldResult>>
1936 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1937  RankedTensorType expandedType,
1938  ArrayRef<ReassociationIndices> reassociation,
1939  ArrayRef<OpFoldResult> inputShape) {
1940  std::optional<SmallVector<OpFoldResult>> outputShape =
1941  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1942  inputShape);
1943  if (!outputShape)
1944  return failure();
1945  return *outputShape;
1946 }
1947 
1948 SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1949  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1950 }
1951 
1952 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1953  Type resultType, Value src,
1954  ArrayRef<ReassociationIndices> reassociation,
1955  ArrayRef<OpFoldResult> outputShape) {
1956  auto [staticOutputShape, dynamicOutputShape] =
1957  decomposeMixedValues(outputShape);
1958  build(builder, result, cast<RankedTensorType>(resultType), src,
1959  getReassociationIndicesAttribute(builder, reassociation),
1960  dynamicOutputShape, staticOutputShape);
1961 }
1962 
1963 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1964  Type resultType, Value src,
1965  ArrayRef<ReassociationIndices> reassociation) {
1966  SmallVector<OpFoldResult> inputShape =
1967  getMixedSizes(builder, result.location, src);
1968  auto tensorResultTy = cast<RankedTensorType>(resultType);
1969  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1970  builder, result.location, tensorResultTy, reassociation, inputShape);
1971  SmallVector<OpFoldResult> outputShapeOrEmpty;
1972  if (succeeded(outputShape)) {
1973  outputShapeOrEmpty = *outputShape;
1974  }
1975  build(builder, result, tensorResultTy, src, reassociation,
1976  outputShapeOrEmpty);
1977 }
1978 
1979 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1980  return getSymbolLessAffineMaps(getReassociationExprs());
1981 }
1982 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1983  return convertReassociationIndicesToExprs(getContext(),
1984  getReassociationIndices());
1985 }
1986 
1987 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1988  return getSymbolLessAffineMaps(getReassociationExprs());
1989 }
1990 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1991  return convertReassociationIndicesToExprs(getContext(),
1992  getReassociationIndices());
1993 }
1994 
1995 RankedTensorType CollapseShapeOp::inferCollapsedType(
1996  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1997  return inferCollapsedType(
1998  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1999  type.getContext(), reassociation)));
2000 }
2001 
2002 /// Compute the RankedTensorType obtained by applying `reassociation` to
2003 /// `type`.
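/// For example (illustrative shapes):
/// ```mlir
/// %a = tensor.collapse_shape %t0 [[0, 1], [2]]
///     : tensor<2x3x4xf32> into tensor<6x4xf32>
/// %b = tensor.collapse_shape %t1 [[0, 1], [2]]
///     : tensor<2x?x4xf32> into tensor<?x4xf32>
/// ```
/// A dynamic dimension in a reassociation group makes the corresponding
/// collapsed dimension dynamic.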
2004 RankedTensorType
2005 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2006  ArrayRef<AffineMap> reassociation) {
2007  auto shape = type.getShape();
2008  SmallVector<int64_t, 4> newShape;
2009  newShape.reserve(reassociation.size());
2010 
2011  // Use the fact that reassociation is valid to simplify the logic: only use
2012  // each map's rank.
2013  assert(isReassociationValid(reassociation) && "invalid reassociation");
2014  unsigned currentDim = 0;
2015  for (AffineMap m : reassociation) {
2016  unsigned dim = m.getNumResults();
2017  auto band = shape.slice(currentDim, dim);
2018  int64_t size = 1;
2019  if (llvm::is_contained(band, ShapedType::kDynamic))
2020  size = ShapedType::kDynamic;
2021  else
2022  for (unsigned d = 0; d < dim; ++d)
2023  size *= shape[currentDim + d];
2024  newShape.push_back(size);
2025  currentDim += dim;
2026  }
2027 
2028  return RankedTensorType::get(newShape, type.getElementType());
2029 }
2030 
2031 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2032  ArrayRef<ReassociationIndices> reassociation,
2033  ArrayRef<NamedAttribute> attrs) {
2034  auto resultType = inferCollapsedType(
2035  llvm::cast<RankedTensorType>(src.getType()),
2036  getSymbolLessAffineMaps(
2037  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2038  result.addAttribute(getReassociationAttrStrName(),
2039  getReassociationIndicesAttribute(b, reassociation));
2040  build(b, result, resultType, src, attrs);
2041 }
2042 
2043 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2044  TensorReshapeOp, ExpandShapeOp>::value>
2045 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2046  RankedTensorType expandedType,
2047  RankedTensorType collapsedType) {
2048  if (failed(
2049  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2050  return failure();
2051 
2052  auto maps = op.getReassociationMaps();
2053  RankedTensorType expectedType =
2054  CollapseShapeOp::inferCollapsedType(expandedType, maps);
2055  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2056  return op.emitOpError("expected collapsed type to be ")
2057  << expectedType << ", but got " << collapsedType;
2058  return success();
2059 }
2060 
2061 LogicalResult ExpandShapeOp::verify() {
2062  auto srcType = getSrcType();
2063  auto resultType = getResultType();
2064 
2065  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2066  return emitOpError("expected number of static shape dims to be equal to "
2067  "the output rank (")
2068  << resultType.getRank() << ") but found "
2069  << getStaticOutputShape().size() << " inputs instead";
2070 
2071  if ((int64_t)getOutputShape().size() !=
2072  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2073  return emitOpError("mismatch in dynamic dims in output_shape and "
2074  "static_output_shape: static_output_shape has ")
2075  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2076  << " dynamic dims while output_shape has " << getOutputShape().size()
2077  << " values";
2078 
2079  return verifyTensorReshapeOp(*this, resultType, srcType);
2080 }
2081 
2082 LogicalResult CollapseShapeOp::verify() {
2083  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2084 }
2085 
2086 namespace {
2087 /// Reshape of a splat constant can be replaced with a constant of the result
2088 /// type.
2089 template <typename TensorReshapeOp>
2090 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2091  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2092  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2093  PatternRewriter &rewriter) const override {
2094  DenseElementsAttr attr;
2095  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2096  return failure();
2097  if (!attr || !attr.isSplat())
2098  return failure();
2099  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2100  reshapeOp.getResultType(), attr.getRawData());
2101  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2102  return success();
2103  }
2104 };
2105 
2106 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
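// A sketch on hypothetical IR:
//   %s = tensor.splat %v : tensor<2x4xf32>
//   %r = tensor.expand_shape %s [[0], [1, 2]] output_shape [2, 2, 2]
//       : tensor<2x4xf32> into tensor<2x2x2xf32>
// folds to
//   %r = tensor.splat %v : tensor<2x2x2xf32>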
2107 template <typename TensorReshapeOp>
2108 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2109 public:
2110  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2111 
2112  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2113  PatternRewriter &rewriter) const override {
2114  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2115  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2116  return failure();
2117 
2118  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2119  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2120  return success();
2121  }
2122 };
2123 
2124 /// Reshape of a FromElements can be replaced with a FromElements of the
2125 /// result type
2126 template <typename TensorReshapeOp>
2127 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2128  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2129  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2130  PatternRewriter &rewriter) const override {
2131  auto fromElements =
2132  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2133  if (!fromElements)
2134  return failure();
2135 
2136  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2137 
2138  if (!shapedTy.hasStaticShape())
2139  return failure();
2140 
2141  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2142  fromElements.getElements());
2143  return success();
2144  }
2145 };
2146 
2147 // Fold CastOp into CollapseShapeOp when adding static information.
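// A sketch on hypothetical IR:
//   %c = tensor.cast %t : tensor<?x4xf32> to tensor<?x?xf32>
//   %r = tensor.collapse_shape %c [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
// becomes a collapse_shape of %t directly; a tensor.cast of the result is
// inserted only if the more static source changes the collapsed type.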
2148 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2149  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2150 
2151  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2152  PatternRewriter &rewriter) const override {
2153  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2154  if (!tensor::canFoldIntoConsumerOp(castOp))
2155  return failure();
2156 
2157  RankedTensorType srcType =
2158  llvm::cast<RankedTensorType>(castOp.getSource().getType());
2159  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2160  srcType, collapseShapeOp.getReassociationMaps());
2161 
2162  if (newResultType == collapseShapeOp.getResultType()) {
2163  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2164  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2165  });
2166  } else {
2167  auto newOp = rewriter.create<CollapseShapeOp>(
2168  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
2169  collapseShapeOp.getReassociation());
2170  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2171  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2172  }
2173  return success();
2174  }
2175 };
2176 
2177 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2178 /// matching constant output_shape operands of the expand. This makes the
2179 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2180 /// propagated further.
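/// A sketch on hypothetical IR, where %c2 and %c3 are arith.constant index
/// values 2 and 3:
/// ```mlir
/// %c = tensor.cast %t : tensor<6xf32> to tensor<?xf32>
/// %e = tensor.expand_shape %c [[0, 1]] output_shape [%c2, %c3]
///     : tensor<?xf32> into tensor<?x?xf32>
/// ```
/// is rewritten to expand tensor<6xf32> into tensor<2x3xf32>, followed by a
/// tensor.cast back to the original tensor<?x?xf32> result type.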
2181 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2182  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2183 
2184  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2185  PatternRewriter &rewriter) const override {
2186  auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2187  if (!canFoldIntoConsumerOp(castOp))
2188  return failure();
2189 
2190  ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2191  SmallVector<ReassociationIndices, 4> reassoc =
2192  expandOp.getReassociationIndices();
2193 
2194  SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2195  SmallVector<Value> dynamicOutputShape;
2196  auto outputIt = expandOp.getOutputShape().begin();
2197 
2198  for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2199  for (uint64_t outDim : innerReassoc) {
2200  if (!ShapedType::isDynamic(newOutputShape[outDim]))
2201  continue;
2202 
2203  // If the cast's src type is dynamic, don't infer any of the
2204  // corresponding expanded dimensions. `tensor.expand_shape` requires at
2205  // least one of the expanded dimensions to be dynamic if the input is
2206  // dynamic.
2207  Value val = *outputIt;
2208  ++outputIt;
2209  if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2210  dynamicOutputShape.push_back(val);
2211  continue;
2212  }
2213 
2214  APInt cst;
2215  if (matchPattern(val, m_ConstantInt(&cst))) {
2216  newOutputShape[outDim] = cst.getSExtValue();
2217  } else {
2218  dynamicOutputShape.push_back(val);
2219  }
2220  }
2221  }
2222 
2223  // Couldn't match any values, nothing to change
2224  if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2225  return failure();
2226 
2227  // Calculate the input shape from the output
2228  SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2229  for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2230  for (auto outDim : reassoc[inDim]) {
2231  auto ofr = newOutputShape[outDim];
2232  if (ShapedType::isDynamic(ofr)) {
2233  newInputShape[inDim] = ShapedType::kDynamic;
2234  break;
2235  }
2236  newInputShape[inDim] *= ofr;
2237  }
2238  }
2239 
2240  SmallVector<OpFoldResult> outputOfr =
2241  getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2242  auto inputType = RankedTensorType::get(
2243  newInputShape, expandOp.getSrcType().getElementType());
2244  auto outputType = RankedTensorType::get(
2245  newOutputShape, expandOp.getSrcType().getElementType());
2246  auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
2247  expandOp.getSrc());
2248  auto newExpand = rewriter.create<ExpandShapeOp>(
2249  expandOp.getLoc(), outputType, inputCast.getResult(),
2250  expandOp.getReassociationIndices(), outputOfr);
2251  rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2252  newExpand.getResult());
2253  return success();
2254  }
2255 };
2256 } // namespace
2257 
2258 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2259  MLIRContext *context) {
2260  results.add<
2261  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2262  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2263  ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2264  FoldReshapeWithSplat<ExpandShapeOp>,
2265  FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2266 }
2267 
2268 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2269  MLIRContext *context) {
2270  results.add<
2271  ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2272  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2273  tensor::DimOp, RankedTensorType>,
2274  FoldReshapeWithConstant<CollapseShapeOp>,
2275  FoldReshapeWithSplat<CollapseShapeOp>,
2276  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2277  context);
2278 }
2279 
2280 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2281  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2282  adaptor.getOperands());
2283 }
2284 
2285 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2286  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2287  adaptor.getOperands());
2288 }
2289 
2290 //===----------------------------------------------------------------------===//
2291 // ExtractSliceOp
2292 //===----------------------------------------------------------------------===//
2293 
2294 void ExtractSliceOp::getAsmResultNames(
2295  function_ref<void(Value, StringRef)> setNameFn) {
2296  setNameFn(getResult(), "extracted_slice");
2297 }
2298 
2299 /// An extract_slice result type can be inferred, when it is not
2300 /// rank-reduced, from the source type and the static representation of
2301 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
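/// For example (hypothetical operands), offsets [0, %o], sizes [4, %s] and
/// strides [1, 1] on a tensor<8x16xf32> source infer:
/// ```mlir
/// %1 = tensor.extract_slice %0[0, %o] [4, %s] [1, 1]
///     : tensor<8x16xf32> to tensor<4x?xf32>
/// ```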
2302 RankedTensorType ExtractSliceOp::inferResultType(
2303  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2304  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2305  // An extract_slice op may specify only a leading subset of offset/sizes/
2306  // strides, in which case we complete them with offset=0, sizes from the
2307  // source tensor type, and strides=1.
2308  assert(static_cast<int64_t>(staticSizes.size()) ==
2309  sourceTensorType.getRank() &&
2310  "unexpected staticSizes not equal to rank of source");
2311  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2312  sourceTensorType.getEncoding());
2313 }
2314 
2315 RankedTensorType ExtractSliceOp::inferResultType(
2316  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2317  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2318  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2319  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2320  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2321  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2322  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2323  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2324  staticSizes, staticStrides);
2325 }
2326 
2327 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2328 /// number of sizes), drop as many size-1 dimensions as needed to produce an
2329 /// inferred type with the desired rank.
2330 ///
2331 /// Note that there may be multiple ways to compute this rank-reduced type:
2332 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2333 ///
2334 /// To disambiguate, this function always drops the earliest occurrences of size-1 dimensions.
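/// For illustration (hypothetical shapes): requesting a rank-2 result for
/// sizes [1, 6, 1] drops the leading unit dimension first, e.g.
/// ```mlir
/// %rr = tensor.extract_slice %t[0, 0, 0] [1, 6, 1] [1, 1, 1]
///     : tensor<8x6x4xf32> to tensor<6x1xf32>
/// ```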
2335 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2336  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2337  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2338  ArrayRef<int64_t> strides) {
2339  // Type inferred in the absence of rank-reducing behavior.
2340  auto inferredType = llvm::cast<RankedTensorType>(
2341  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2342  int rankDiff = inferredType.getRank() - desiredResultRank;
2343  if (rankDiff > 0) {
2344  auto shape = inferredType.getShape();
2345  llvm::SmallBitVector dimsToProject =
2346  getPositionsOfShapeOne(rankDiff, shape);
2347  SmallVector<int64_t> projectedShape;
2348  // Best effort rank-reducing: drop 1s in order.
2349  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2350  if (!dimsToProject.test(pos))
2351  projectedShape.push_back(shape[pos]);
2352  inferredType =
2353  RankedTensorType::get(projectedShape, inferredType.getElementType());
2354  }
2355  return inferredType;
2356 }
2357 
2358 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2359  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2360  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2361  ArrayRef<OpFoldResult> strides) {
2362  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2363  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2364  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2365  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2366  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2367  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2368  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2369  staticStrides);
2370 }
2371 
2372 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2373 /// result type. If the type passed is nullptr, it is inferred.
2374 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2375  RankedTensorType resultType, Value source,
2376  ArrayRef<OpFoldResult> offsets,
2377  ArrayRef<OpFoldResult> sizes,
2378  ArrayRef<OpFoldResult> strides,
2379  ArrayRef<NamedAttribute> attrs) {
2380  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2381  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2382  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2383  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2384  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2385  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2386  // Structuring implementation this way avoids duplication between builders.
2387  if (!resultType) {
2388  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2389  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2390  }
2391  result.addAttributes(attrs);
2392  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2393  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2394  b.getDenseI64ArrayAttr(staticSizes),
2395  b.getDenseI64ArrayAttr(staticStrides));
2396 }
2397 
2398 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2399 /// result type.
2400 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2401  ArrayRef<OpFoldResult> offsets,
2402  ArrayRef<OpFoldResult> sizes,
2403  ArrayRef<OpFoldResult> strides,
2404  ArrayRef<NamedAttribute> attrs) {
2405  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2406 }
2407 
2408 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2409 /// a Range vector.
2410 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2411  ArrayRef<Range> ranges,
2412  ArrayRef<NamedAttribute> attrs) {
2413  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2414  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2415 }
2416 
2417 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2418 /// the type passed is nullptr, it is inferred.
2419 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2420  RankedTensorType resultType, Value source,
2421  ValueRange offsets, ValueRange sizes,
2422  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2423  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2424  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2425  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2426  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2427  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2428  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2429  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2430 }
2431 
2432 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2433 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2434  ValueRange offsets, ValueRange sizes,
2435  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2436  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2437 }
2438 
2439 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2440  Operation *op,
2441  RankedTensorType expectedType) {
2442  switch (result) {
2443  case SliceVerificationResult::Success:
2444  return success();
2445  case SliceVerificationResult::RankTooLarge:
2446  return op->emitError("expected rank to be smaller or equal to ")
2447  << "the other rank. ";
2448  case SliceVerificationResult::SizeMismatch:
2449  return op->emitError("expected type to be ")
2450  << expectedType << " or a rank-reduced version. (size mismatch) ";
2451  case SliceVerificationResult::ElemTypeMismatch:
2452  return op->emitError("expected element type to be ")
2453  << expectedType.getElementType();
2454  default:
2455  llvm_unreachable("unexpected extract_slice op verification result");
2456  }
2457 }
2458 
2459 /// Verifier for ExtractSliceOp.
2460 LogicalResult ExtractSliceOp::verify() {
2461  RankedTensorType sourceType = getSourceType();
2462 
2463  // Verify result type against inferred type.
2464  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2465  sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2466  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2467  if (result != SliceVerificationResult::Success)
2468  return produceSliceErrorMsg(result, *this, expectedType);
2469 
2470  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2471  // to the source tensor.
2472  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2473  sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2474  getStaticStrides(), /*generateErrorMessage=*/true);
2475  if (!boundsResult.isValid)
2476  return getOperation()->emitError(boundsResult.errorMessage);
2477 
2478  return success();
2479 }
2480 
2481 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2482  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2483 }
2484 
2485 FailureOr<Value>
2486 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2487  ArrayRef<int64_t> desiredShape) {
2488  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2489  assert(sourceTensorType && "not a ranked tensor type");
2490  auto sourceShape = sourceTensorType.getShape();
2491  if (sourceShape.equals(desiredShape))
2492  return value;
2493  auto maybeRankReductionMask =
2494  mlir::computeRankReductionMask(sourceShape, desiredShape);
2495  if (!maybeRankReductionMask)
2496  return failure();
2497  return createCanonicalRankReducingExtractSliceOp(
2498  b, loc, value,
2499  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2500 }
2501 
2502 LogicalResult ExtractSliceOp::reifyResultShapes(
2503  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2504  reifiedReturnShapes.resize(1);
2505  reifiedReturnShapes[0].reserve(getType().getRank());
2506  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2507  llvm::SmallBitVector droppedDims = getDroppedDims();
2508  for (const auto &size : enumerate(mixedSizes)) {
2509  if (droppedDims.test(size.index()))
2510  continue;
2511  reifiedReturnShapes[0].push_back(size.value());
2512  }
2513  return success();
2514 }
2515 
2516 namespace {
2517 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2518 /// This essentially pushes the tensor.cast past its consuming slice when
2519 /// `canFoldIntoConsumerOp` is true.
2520 ///
2521 /// Example:
2522 /// ```
2523 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2524 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2525 /// tensor<3x4xf32>
2526 /// ```
2527 /// is rewritten into:
2528 /// ```
2529 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2530 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2531 /// ```
2532 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2533 public:
2534  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2535 
2536  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2537  PatternRewriter &rewriter) const override {
2538  // Any constant operand, just return to let the constant folder kick in.
2539  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2540  return matchPattern(operand, matchConstantIndex());
2541  }))
2542  return failure();
2543 
2544  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2545  if (!castOp)
2546  return failure();
2547 
2548  if (!canFoldIntoConsumerOp(castOp))
2549  return failure();
2550 
2551  // Pattern does not apply if the produced op would not verify.
2552  SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2553  cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2554  sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2555  sliceOp.getStaticStrides());
2556  if (!sliceResult.isValid)
2557  return failure();
2558 
2559  // Create folded extract.
2560  Location loc = sliceOp.getLoc();
2561  Value newResult = rewriter.create<ExtractSliceOp>(
2562  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2563  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2564  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2565  rewriter.replaceOp(sliceOp, newResult);
2566  return success();
2567  }
2568 };
2569 
2570 /// Slice elements from `values` into `outValues`. `counts[i]` is the linearized
2571 /// stride of dimension `i`, i.e. the number of elements spanned by one step along it.
2572 /// The output values can be used to construct a DenseElementsAttr.
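/// As a worked example (hypothetical data): for a row-major 2x3 source,
/// counts = [3, 1]; with offsets [0, 1], sizes [2, 2], strides [1, 1] the
/// elements at linear positions 1, 2, 4 and 5 are copied into `outValues`.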
2573 template <typename IterTy, typename ElemTy>
2574 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2575  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2576  ArrayRef<int64_t> strides,
2577  llvm::SmallVectorImpl<ElemTy> *outValues) {
2578  assert(offsets.size() == sizes.size());
2579  assert(offsets.size() == strides.size());
2580  if (offsets.empty())
2581  return;
2582 
2583  int64_t offset = offsets.front();
2584  int64_t size = sizes.front();
2585  int64_t stride = strides.front();
2586  if (offsets.size() == 1) {
2587  for (int64_t i = 0; i < size; ++i, offset += stride)
2588  outValues->push_back(*(values + offset));
2589 
2590  return;
2591  }
2592 
2593  for (int64_t i = 0; i < size; ++i, offset += stride) {
2594  auto begin = values + offset * counts.front();
2595  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2596  offsets.drop_front(), sizes.drop_front(),
2597  strides.drop_front(), outValues);
2598  }
2599 }
2600 
2601 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2602 /// folded operation might introduce more constant data, so users can control
2603 /// whether the fold applies via the control function.
2604 class ConstantOpExtractSliceFolder final
2605  : public OpRewritePattern<ExtractSliceOp> {
2606 public:
2607  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2608 
2609  ConstantOpExtractSliceFolder(MLIRContext *context,
2610  ControlConstantExtractSliceFusionFn controlFn)
2611  : OpRewritePattern<ExtractSliceOp>(context),
2612  controlFn(std::move(controlFn)) {}
2613 
2614  LogicalResult matchAndRewrite(ExtractSliceOp op,
2615  PatternRewriter &rewriter) const override {
2616  DenseElementsAttr attr;
2617  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2618  return failure();
2619 
2620  // A constant splat is handled by fold().
2621  if (attr.isSplat())
2622  return failure();
2623 
2624  // Dynamic result shape is not supported.
2625  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2626  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2627  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2628  return failure();
2629 
2630  // Customized control over the folding.
2631  if (!controlFn(op))
2632  return failure();
2633 
2634  int64_t count = sourceType.getNumElements();
2635  if (count == 0)
2636  return failure();
2637 
2638  // Check if there are any dynamic parts, which are not supported.
2639  auto offsets = op.getStaticOffsets();
2640  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2641  return failure();
2642  auto sizes = op.getStaticSizes();
2643  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2644  return failure();
2645  auto strides = op.getStaticStrides();
2646  if (llvm::is_contained(strides, ShapedType::kDynamic))
2647  return failure();
2648 
2649  // Compute the stride for each dimension.
2650  SmallVector<int64_t> counts;
2651  ArrayRef<int64_t> shape = sourceType.getShape();
2652  counts.reserve(shape.size());
2653  for (int64_t v : shape) {
2654  count = count / v;
2655  counts.push_back(count);
2656  }
2657 
2658  // New attribute constructed by the sliced values.
2659  DenseElementsAttr newAttr;
2660 
2661  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2662  SmallVector<APInt> outValues;
2663  outValues.reserve(sourceType.getNumElements());
2664  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2665  elems.begin(), counts, offsets, sizes, strides, &outValues);
2666  newAttr = DenseElementsAttr::get(resultType, outValues);
2667  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2668  SmallVector<APFloat> outValues;
2669  outValues.reserve(sourceType.getNumElements());
2670  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2671  elems.begin(), counts, offsets, sizes, strides, &outValues);
2672  newAttr = DenseElementsAttr::get(resultType, outValues);
2673  }
2674 
2675  if (newAttr) {
2676  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2677  return success();
2678  }
2679 
2680  return failure();
2681  }
2682 
2683 private:
2684  /// Additionally controls whether the fold is applied; users can implement
2685  /// their own heuristics in this function.
2686  ControlConstantExtractSliceFusionFn controlFn;
2687 };
2688 
2689 } // namespace
2690 
2691 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2692  RewritePatternSet &patterns,
2693  const ControlConstantExtractSliceFusionFn &controlFn) {
2694  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2695 }
2696 
2697 /// Return the canonical type of the result of an extract_slice op.
2698 struct SliceReturnTypeCanonicalizer {
2699  RankedTensorType operator()(ExtractSliceOp op,
2700  ArrayRef<OpFoldResult> mixedOffsets,
2701  ArrayRef<OpFoldResult> mixedSizes,
2702  ArrayRef<OpFoldResult> mixedStrides) {
2703  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2704  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2705  mixedStrides);
2706  }
2707 };
2708 
2709 /// A canonicalizer wrapper to replace ExtractSliceOps.
2710 struct SliceCanonicalizer {
2711  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2712  ExtractSliceOp newOp) {
2713  Value replacement = newOp.getResult();
2714  if (replacement.getType() != op.getType())
2715  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2716  replacement);
2717  rewriter.replaceOp(op, replacement);
2718  }
2719 };
2720 
2721 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2722  MLIRContext *context) {
2723  results.add<
2724  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2725  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2726  ExtractSliceOpCastFolder>(context);
2727 }
2728 
2729 // Succeeds if `op` describes an identity slice of `shapedType`: all offsets are 0, sizes match the shape, and strides are 1.
2730 static LogicalResult
2731 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2732  ShapedType shapedType) {
2733  OpBuilder b(op.getContext());
2734  for (OpFoldResult ofr : op.getMixedOffsets())
2735  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2736  return failure();
2737  // Rank-reducing noops only need to inspect the leading dimensions:
2738  // llvm::zip is appropriate.
2739  auto shape = shapedType.getShape();
2740  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2741  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2742  return failure();
2743  for (OpFoldResult ofr : op.getMixedStrides())
2744  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2745  return failure();
2746  return success();
2747 }
2748 
2749 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2750 /// slice, we can return the InsertSliceOp's source directly.
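/// Example (illustrative):
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 1] [2, 2] [1, 1]
///     : tensor<2x2xf32> into tensor<4x4xf32>
/// %1 = tensor.extract_slice %0[0, 1] [2, 2] [1, 1]
///     : tensor<4x4xf32> to tensor<2x2xf32>
/// ```
/// Here %1 folds to %src.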
2751 // TODO: This only checks the immediate producer; extend to go up the
2752 // insert/extract chain if the slices are disjoint.
2753 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2754  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2755 
2756  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2757  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2758  insertOp.isSameAs(extractOp, isSame))
2759  return insertOp.getSource();
2760 
2761  return {};
2762 }
2763 
2764 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2765  if (OpFoldResult reshapedSource = reshapeConstantSource(
2766  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2767  getResult().getType()))
2768  return reshapedSource;
2769  if (getSourceType() == getType() &&
2770  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2771  return this->getSource();
2772  if (Value slice = foldExtractAfterInsertSlice(*this))
2773  return slice;
2774 
2775  return OpFoldResult();
2776 }
2777 
2778 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2779  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2780  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2781  unsigned rank = rankedTensorType.getRank();
2782  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2783  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2784  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2785  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2786  offsets, sizes, strides);
2787 }
2788 
2789 //===----------------------------------------------------------------------===//
2790 // InsertSliceOp
2791 //===----------------------------------------------------------------------===//
2792 
2793 void InsertSliceOp::getAsmResultNames(
2794  function_ref<void(Value, StringRef)> setNameFn) {
2795  setNameFn(getResult(), "inserted_slice");
2796 }
2797 
2798 // Build an InsertSliceOp with mixed static and dynamic entries.
2799 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2800  Value dest, ArrayRef<OpFoldResult> offsets,
2801  ArrayRef<OpFoldResult> sizes,
2802  ArrayRef<OpFoldResult> strides,
2803  ArrayRef<NamedAttribute> attrs) {
2804  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2805  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2806  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2807  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2808  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2809  result.addAttributes(attrs);
2810  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2811  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2812  b.getDenseI64ArrayAttr(staticSizes),
2813  b.getDenseI64ArrayAttr(staticStrides));
2814 }
2815 
2816 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2817 /// Range vector.
2818 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2819  Value dest, ArrayRef<Range> ranges,
2820  ArrayRef<NamedAttribute> attrs) {
2821  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2822  build(b, result, source, dest, offsets, sizes, strides, attrs);
2823 }
2824 
2825 // Build an InsertSliceOp with dynamic entries.
2826 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2827  Value dest, ValueRange offsets, ValueRange sizes,
2828  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2829  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2830  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2831  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2832  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2833  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2834  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2835  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2836 }
2837 
2838 /// Rank-reducing type verification for both InsertSliceOp and
2839 /// ParallelInsertSliceOp.
2840 static SliceVerificationResult verifyInsertSliceOp(
2841  RankedTensorType srcType, RankedTensorType dstType,
2842  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2843  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2844  // insert_slice is the inverse of extract_slice, use the same type
2845  // inference.
2846  RankedTensorType expected = ExtractSliceOp::inferResultType(
2847  dstType, staticOffsets, staticSizes, staticStrides);
2848  if (expectedType)
2849  *expectedType = expected;
2850  return isRankReducedType(expected, srcType);
2851 }
2852 
2853 /// Verifier for InsertSliceOp.
2854 LogicalResult InsertSliceOp::verify() {
2855  // Verify result type against inferred type.
2856  RankedTensorType expectedType;
2857  SliceVerificationResult result =
2858  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2859  getStaticSizes(), getStaticStrides(), &expectedType);
2860  if (result != SliceVerificationResult::Success)
2861  return produceSliceErrorMsg(result, *this, expectedType);
2862 
2863  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2864  // to the destination tensor.
2865  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2866  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2867  getStaticStrides(), /*generateErrorMessage=*/true);
2868  if (!boundsResult.isValid)
2869  return getOperation()->emitError(boundsResult.errorMessage);
2870 
2871  return success();
2872 }
2873 
2874 /// If we have two consecutive InsertSliceOps writing to the same slice, we
2875 /// can redirect the second InsertSliceOp's destination to that of the first.
2876 ///
2877 /// Example:
2878 ///
2879 /// ```mlir
2880 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2881 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2882 /// ```
2883 ///
2884 /// folds into:
2885 ///
2886 /// ```mlir
2887 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2888 /// ```
2889 ///
2890 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2891 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2892  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2893 
2894  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2895  if (!prevInsertOp ||
2896  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2897  !prevInsertOp.isSameAs(insertOp, isSame))
2898  return failure();
2899 
2900  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2901  return success();
2902 }
2903 
2904 /// Folds round-trip extract/insert slice op pairs.
2905 /// Example:
2906 /// ```mlir
2907 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2908 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2909 /// ```
2910 /// can be folded into %val.
2911 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2912  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2913 
2914  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2915  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2916  !extractOp.isSameAs(insertOp, isSame))
2917  return nullptr;
2918 
2919  return extractOp.getSource();
2920 }
2921 
2922 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2923  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2924  getSourceType() == getType() &&
2925  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2926  return this->getSource();
2927  if (succeeded(foldInsertAfterInsertSlice(*this)))
2928  return getResult();
2929  if (auto result = foldInsertAfterExtractSlice(*this))
2930  return result;
2931  if (llvm::any_of(getMixedSizes(), isZeroInteger))
2932  return getDest();
2933  return OpFoldResult();
2934 }
2935 
2936 LogicalResult InsertSliceOp::reifyResultShapes(
2937  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2938  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2939  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2940  return success();
2941 }
2942 
2943 namespace {
2944 /// Pattern to rewrite an insert_slice op with constant arguments.
2945 ///
2946 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
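/// A sketch of the effect on hypothetical IR, where %c2 is an arith.constant
/// index of value 2:
/// ```mlir
/// %r = tensor.insert_slice %src into %dest[%c2, 0] [2, 2] [1, 1]
///     : tensor<2x2xf32> into tensor<8x8xf32>
/// ```
/// is rewritten so that the offset 2 is carried statically by the op instead
/// of as an SSA operand.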
2947 template <typename InsertOpTy>
2948 class InsertSliceOpConstantArgumentFolder final
2949  : public OpRewritePattern<InsertOpTy> {
2950 public:
2951  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2952 
2953  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2954  PatternRewriter &rewriter) const override {
2955  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2956  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2957  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2958 
2959  // No constant operands were folded, just return;
2960  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2961  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2962  failed(foldDynamicStrideList(mixedStrides)))
2963  return failure();
2964 
2965  // Pattern does not apply if the produced op would not verify.
2966  SliceBoundsVerificationResult sliceResult =
2967  verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2968  mixedOffsets, mixedSizes, mixedStrides);
2969  if (!sliceResult.isValid)
2970  return failure();
2971 
2972  // Create the new op in canonical form.
2973  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2974  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2975  mixedOffsets, mixedSizes, mixedStrides);
2976  Value toInsert = insertSliceOp.getSource();
2977  if (sourceType != insertSliceOp.getSourceType()) {
2978  OpBuilder::InsertionGuard g(rewriter);
2979  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2980  // is that the insertion point is just before the ParallelCombiningOp in
2981  // the parallel case.
2982  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2983  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2984  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2985  sourceType, toInsert);
2986  }
2987  rewriter.replaceOpWithNewOp<InsertOpTy>(
2988  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2989  mixedSizes, mixedStrides);
2990  return success();
2991  }
2992 };
2993 
2994 /// Fold tensor_casts with insert_slice operations. If the source or
2995 /// destination tensor is a tensor_cast that removes static type information,
2996 /// the cast is folded into the insert_slice operation. E.g.:
2997 ///
2998 /// ```mlir
2999 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3000 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3001 /// ```
3002 ///
3003 /// folds into:
3004 ///
3005 /// ```mlir
3006 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3007 /// ```
3008 ///
 3009 /// Note: When folding a cast on the destination tensor, the result of the
 3010 /// insert_slice operation is cast back so that the type of the result does
 3011 /// not change.
3012 ///
3013 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3014 template <typename InsertOpTy>
3015 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
 3016  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 3017 
3018  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3019  PatternRewriter &rewriter) const override {
3020  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3021  return matchPattern(operand, matchConstantIndex());
3022  }))
3023  return failure();
3024 
3025  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3026  auto castOp = v.getDefiningOp<tensor::CastOp>();
3027  if (!castOp || !canFoldIntoConsumerOp(castOp))
3028  return std::nullopt;
3029  return castOp.getSource();
3030  };
3031  std::optional<Value> sourceCastSource =
3032  getSourceOfCastOp(insertSliceOp.getSource());
3033  std::optional<Value> destCastSource =
3034  getSourceOfCastOp(insertSliceOp.getDest());
3035  if (!sourceCastSource && !destCastSource)
3036  return failure();
3037 
3038  auto src =
3039  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3040  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3041  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3042  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3043  if (!srcType || !dstType)
3044  return failure();
3045 
3046  // The tensor.cast source could have additional static information not seen
3047  // in the insert slice op static sizes, so we ignore dynamic dims when
3048  // computing the rank reduction mask.
3049  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3050  auto rankReductionMask = computeRankReductionMask(
3051  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3052  if (!rankReductionMask.has_value())
3053  return failure();
3054  // Replace dimensions in the insert slice op with corresponding static dims
3055  // from the cast source type. If the insert slice sizes have static dims
3056  // that are not static in the tensor.cast source (i.e., when the cast op
3057  // casts a dynamic dim to static), the dim should not be replaced, and the
3058  // pattern will fail later in `verifyInsertSliceOp`.
3059  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3060  int64_t rankReducedIdx = 0;
3061  for (auto [idx, size] : enumerate(staticSizes)) {
3062  if (!rankReductionMask.value().contains(idx) &&
3063  !srcType.isDynamicDim(rankReducedIdx)) {
3064  mixedSizes[idx] = getAsIndexOpFoldResult(
3065  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3066  size = srcType.getDimSize(rankReducedIdx++);
3067  }
3068  }
3069 
3070  // Pattern does not apply if the produced op would not verify.
3071  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3072  staticSizes, insertSliceOp.getStaticStrides()) !=
 3073  SliceVerificationResult::Success)
 3074  return failure();
3075  SliceBoundsVerificationResult sliceResult =
3076  verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3077  mixedSizes, insertSliceOp.getMixedStrides());
3078  if (!sliceResult.isValid)
3079  return failure();
3080 
3081  Operation *replacement = rewriter.create<InsertOpTy>(
3082  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
3083  mixedSizes, insertSliceOp.getMixedStrides());
3084 
3085  // In the parallel case there is no result and so nothing to cast.
3086  bool isParallelInsert =
3087  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3088  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3089  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
3090  insertSliceOp.getDestType(),
3091  replacement->getResult(0));
3092  }
3093  rewriter.replaceOp(insertSliceOp, replacement->getResults());
3094  return success();
3095  }
3096 };
3097 
 3098 /// If additional static type information can be deduced from an insert_slice's
3099 /// size operands, insert an explicit cast of the op's source operand. This
3100 /// enables other canonicalization patterns that are matching for tensor_cast
3101 /// ops such as `ForOpTensorCastFolder` in SCF.
3102 ///
3103 /// Example:
3104 ///
3105 /// ```mlir
3106 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3107 /// : tensor<?x?xf32> into ...
3108 /// ```
3109 ///
3110 /// folds into:
3111 ///
3112 /// ```mlir
3113 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3114 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3115 /// : tensor<64x64xf32> into ...
3116 /// ```
3117 ///
 3118 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3119 template <typename InsertOpTy>
3120 struct InsertSliceOpSourceCastInserter final
3121  : public OpRewritePattern<InsertOpTy> {
 3122  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 3123 
3124  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3125  PatternRewriter &rewriter) const override {
3126  RankedTensorType srcType = insertSliceOp.getSourceType();
3127  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3128  return failure();
3129  SmallVector<int64_t> newSrcShape(srcType.getShape());
3130  for (int64_t i = 0; i < srcType.getRank(); ++i) {
3131  if (std::optional<int64_t> constInt =
3132  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3133  // Bail on invalid IR.
3134  if (*constInt < 0)
3135  return failure();
3136  newSrcShape[i] = *constInt;
3137  }
3138  }
3139  if (!hasValidSizesOffsets(newSrcShape))
3140  return failure();
3141 
3142  RankedTensorType newSrcType = RankedTensorType::get(
3143  newSrcShape, srcType.getElementType(), srcType.getEncoding());
3144  if (srcType == newSrcType ||
3145  !preservesStaticInformation(srcType, newSrcType) ||
3146  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3147  return failure();
3148 
3149  // newSrcType is:
3150  // 1) Different from srcType.
3151  // 2) "More static" than srcType.
3152  // 3) Cast-compatible with srcType.
3153  // Insert the cast.
3154  OpBuilder::InsertionGuard g(rewriter);
3155  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3156  // that the insertion point is just before the ParallelCombiningOp in the
3157  // parallel case.
3158  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
3159  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3160  Value cast = rewriter.create<tensor::CastOp>(
3161  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
3162  rewriter.replaceOpWithNewOp<InsertOpTy>(
3163  insertSliceOp, cast, insertSliceOp.getDest(),
3164  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3165  insertSliceOp.getMixedStrides());
3166  return success();
3167  }
3168 };
3169 } // namespace
3170 
3171 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3172  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3173 }
3174 
3175 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3176  MLIRContext *context) {
3177  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3178  InsertSliceOpCastFolder<InsertSliceOp>,
3179  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3180 }
3181 
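/// Sketch of the op built below (illustrative shapes): the destination is
/// written in full, with zero offsets, the destination's sizes, and unit
/// strides, so a source of lower rank is inserted in a rank-reducing way:
///
/// ```mlir
/// %r = tensor.insert_slice %tensor into %dest[0, 0] [1, 16] [1, 1]
///     : tensor<16xf32> into tensor<1x16xf32>
/// ```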
 3182 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
 3183  Location loc,
3184  Value tensor,
3185  Value dest) {
3186  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3187  unsigned rank = rankedTensorType.getRank();
3188  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3189  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3190  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3191  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3192  sizes, strides);
3193 }
3194 
3195 //===----------------------------------------------------------------------===//
3196 // PadOp
3197 //===----------------------------------------------------------------------===//
3198 
3199 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3200  setNameFn(getResult(), "padded");
3201 }
3202 
3203 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3204 // supports optional types.
3205 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3206  Type typeToInfer, Type typeToInferFrom) {}
3207 
3208 ParseResult
 3209 parseInferType(OpAsmParser &parser,
 3210  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3211  Type &typeToInfer, Type typeToInferFrom) {
3212  if (optOperand)
3213  typeToInfer = typeToInferFrom;
3214  return success();
3215 }
3216 
3217 LogicalResult PadOp::verify() {
3218  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3219  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3220  auto expectedType =
3221  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3222  if (!expectedType) {
3223  return emitError("failed to infer expectedType from sourceType ")
3224  << sourceType << ", specified resultType is " << resultType;
3225  }
3226  if (resultType.getRank() != expectedType.getRank()) {
3227  return emitError("specified type ")
3228  << resultType << " does not match the inferred type "
3229  << expectedType;
3230  }
3231  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3232  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3233  continue;
3234  if (expectedType.isDynamicDim(i))
3235  continue;
3236  return emitError("specified type ")
3237  << resultType << " does not match the inferred type "
3238  << expectedType;
3239  }
3240 
3241  return success();
3242 }
3243 
3244 LogicalResult PadOp::verifyRegions() {
3245  auto &region = getRegion();
3246  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3247  Block &block = region.front();
3248  if (block.getNumArguments() != rank)
3249  return emitError("expected the block to have ") << rank << " arguments";
3250 
3251  // Note: the number and type of yield values are checked in the YieldOp.
3252  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3253  if (!en.value().isIndex())
3254  return emitOpError("expected block argument ")
3255  << (en.index() + 1) << " to be an index";
3256  }
3257 
3258  // Ensure that the region yields an element of the right type.
3259  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3260  if (yieldOp.getValue().getType() !=
3261  llvm::cast<ShapedType>(getType()).getElementType())
3262  return emitOpError("expected yield type to match shape element type");
3263 
3264  return success();
3265 }
3266 
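/// Worked example for the shape arithmetic below (illustrative values): for a
/// source of type tensor<4x?xf32> with staticLow = [1, 0] and
/// staticHigh = [2, 3], every dimension is inferred as low + size + high,
/// giving tensor<7x?xf32>; a dynamic source dimension or a dynamic padding
/// amount stays dynamic unless resultShape supplies an explicit size.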
3267 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3268  ArrayRef<int64_t> staticLow,
3269  ArrayRef<int64_t> staticHigh,
3270  ArrayRef<int64_t> resultShape) {
3271  unsigned rank = sourceType.getRank();
3272  if (staticLow.size() != rank)
3273  return RankedTensorType();
3274  if (staticHigh.size() != rank)
3275  return RankedTensorType();
3276  if (!resultShape.empty() && resultShape.size() != rank)
3277  return RankedTensorType();
3278 
3279  SmallVector<int64_t, 4> inferredShape;
3280  for (auto i : llvm::seq<unsigned>(0, rank)) {
3281  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3282  staticHigh[i] == ShapedType::kDynamic) {
3283  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3284  : resultShape[i]);
3285  } else {
3286  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3287  assert((resultShape.empty() || size == resultShape[i] ||
3288  resultShape[i] == ShapedType::kDynamic) &&
3289  "mismatch between inferred shape and result shape");
3290  inferredShape.push_back(size);
3291  }
3292  }
3293 
3294  return RankedTensorType::get(inferredShape, sourceType.getElementType());
3295 }
3296 
3297 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3298  Value source, ArrayRef<int64_t> staticLow,
3299  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3300  bool nofold, ArrayRef<NamedAttribute> attrs) {
3301  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3302  if (!resultType)
3303  resultType = inferResultType(sourceType, staticLow, staticHigh);
3304  result.addAttributes(attrs);
3305  build(b, result, resultType, source, low, high,
3306  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3307  nofold ? b.getUnitAttr() : UnitAttr());
3308 }
3309 
3310 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3311  Value source, ValueRange low, ValueRange high, bool nofold,
3312  ArrayRef<NamedAttribute> attrs) {
3313  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3314  unsigned rank = sourceType.getRank();
3315  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3316  build(b, result, resultType, source, staticVector, staticVector, low, high,
3317  nofold, attrs);
3318 }
3319 
3320 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3321  Value source, ArrayRef<OpFoldResult> low,
3322  ArrayRef<OpFoldResult> high, bool nofold,
3323  ArrayRef<NamedAttribute> attrs) {
3324  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3325  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3326  SmallVector<int64_t, 4> staticLow, staticHigh;
 3327  // staticLow and staticHigh carry the full padding configuration.
 3328  // Each entry grows staticLow and staticHigh by one value. If the entry is
 3329  // dynamic (i.e., not a constant), dynamicLow and dynamicHigh also grow by
 3330  // one value.
3331  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3332  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3333  if (!resultType) {
3334  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3335  }
3336  assert(llvm::isa<RankedTensorType>(resultType));
3337  result.addAttributes(attrs);
3338  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3339  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3340  nofold ? b.getUnitAttr() : UnitAttr());
3341 }
3342 
3343 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3344  Value source, ArrayRef<OpFoldResult> low,
3345  ArrayRef<OpFoldResult> high, Value constantPadValue,
3346  bool nofold, ArrayRef<NamedAttribute> attrs) {
3347  build(b, result, resultType, source, low, high, nofold, attrs);
3348 
3349  // Add a region and a block to yield the pad value.
3350  Region *region = result.regions[0].get();
3351  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3352  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3353  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3354 
3355  // `builder.createBlock` changes the insertion point within the block. Create
3356  // a guard to reset the insertion point of the builder after it is destroyed.
3357  OpBuilder::InsertionGuard guard(b);
3358  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3359  b.create<tensor::YieldOp>(result.location, constantPadValue);
3360 }
3361 
3362 llvm::SmallBitVector PadOp::getPaddedDims() {
3363  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3364  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3365  for (const auto &en : enumerate(paddingWidths))
3366  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3367  paddedDims.set(en.index());
3368  };
3369  extractPaddedDims(getMixedLowPad());
3370  extractPaddedDims(getMixedHighPad());
3371  return paddedDims;
3372 }
3373 
3374 namespace {
 3375 // Folds tensor.pad when all low/high padding amounts are statically zero
 3376 // and the nofold attribute is not set.
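//
// Sketch of the rewrite (illustrative IR): a pad with all-zero widths such as
//
// ```mlir
// %p = tensor.pad %t low[0, 0] high[0, 0] { ...
//   } : tensor<?x8xf32> to tensor<4x8xf32>
// ```
//
// is replaced by `tensor.cast %t : tensor<?x8xf32> to tensor<4x8xf32>`.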
3377 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
 3378  using OpRewritePattern<PadOp>::OpRewritePattern;
 3379 
3380  LogicalResult matchAndRewrite(PadOp padTensorOp,
3381  PatternRewriter &rewriter) const override {
3382  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3383  return failure();
3384  if (padTensorOp.getNofold())
3385  return failure();
3386  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3387  padTensorOp, padTensorOp.getResult().getType(),
3388  padTensorOp.getSource());
3389  return success();
3390  }
3391 };
3392 
3393 // Fold CastOp into PadOp when adding static information.
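//
// Sketch of the rewrite (illustrative IR): given a cast that erased static
// sizes,
//
// ```mlir
// %c = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>
// %p = tensor.pad %c low[0, 0] high[1, 1] { ...
//   } : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// the pad is rebuilt with the result type tensor<5x9xf32> inferred from the
// cast's source, and a tensor.cast back to tensor<?x?xf32> replaces the
// original result.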
3394 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
 3395  using OpRewritePattern<PadOp>::OpRewritePattern;
 3396 
3397  LogicalResult matchAndRewrite(PadOp padTensorOp,
3398  PatternRewriter &rewriter) const override {
3399  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3400  if (!tensor::canFoldIntoConsumerOp(castOp))
3401  return failure();
3402 
3403  auto newResultType = PadOp::inferResultType(
3404  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3405  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3406  padTensorOp.getResultType().getShape());
3407 
3408  if (newResultType == padTensorOp.getResultType()) {
3409  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3410  padTensorOp.getSourceMutable().assign(castOp.getSource());
3411  });
3412  } else {
3413  auto newOp = rewriter.create<PadOp>(
3414  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3415  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3416  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3417  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3418  IRMapping mapper;
3419  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3420 
3421  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3422  padTensorOp, padTensorOp.getResultType(), newOp);
3423  }
3424  return success();
3425  }
3426 };
3427 
3428 // Fold CastOp using the result of PadOp back into the latter if it adds
3429 // static information.
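//
// Sketch of the rewrite (illustrative IR): a cast that adds static sizes to
// the pad result,
//
// ```mlir
// %p = tensor.pad %t low[0] high[%h] { ...
//   } : tensor<8xf32> to tensor<?xf32>
// %c = tensor.cast %p : tensor<?xf32> to tensor<12xf32>
// ```
//
// is absorbed by rebuilding the pad with result type tensor<12xf32>,
// provided the pad result has no other users.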
3430 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
 3431  using OpRewritePattern<PadOp>::OpRewritePattern;
 3432 
3433  LogicalResult matchAndRewrite(PadOp padTensorOp,
3434  PatternRewriter &rewriter) const override {
3435  if (!padTensorOp.getResult().hasOneUse())
3436  return failure();
3437  auto tensorCastOp =
3438  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3439  if (!tensorCastOp)
3440  return failure();
3441  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3442  tensorCastOp.getDest().getType()))
3443  return failure();
3444 
3445  auto replacementOp = rewriter.create<PadOp>(
3446  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3447  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3448  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3449  padTensorOp.getHigh(), padTensorOp.getNofold(),
3450  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3451  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3452 
3453  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3454  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3455  return success();
3456  }
3457 };
3458 
3459 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3460 /// different dimensions. The pattern applies if the following preconditions
3461 /// hold:
3462 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3463 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3464 /// 3) the tensor::PadOps perform only high-padding,
3465 /// 4) the tensor::PadOps have the same constant padding value,
3466 /// 5) the tensor::PadOps do not have common padding dimensions,
3467 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3468 /// zero-offset for every dimension.
 3469 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
 3470 ///    the padded source dimensions.
3472 ///
3473 /// Example:
3474 ///
3475 /// ```mlir
3476 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3477 /// : tensor<64x64xf32> to tensor<?x64xf32>
3478 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3479 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3480 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3481 /// : tensor<8x64xf32> to tensor<8x?xf32>
3482 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3483 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3484 /// ```
3485 ///
3486 /// folds into:
3487 ///
3488 /// ```mlir
3489 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3490 /// : tensor<64x64xf32> to tensor<?x?xf32>
3491 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3492 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3493 /// ```
3494 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
 3495  using OpRewritePattern<PadOp>::OpRewritePattern;
 3496 
3497  LogicalResult matchAndRewrite(PadOp padOp,
3498  PatternRewriter &rewriter) const override {
3499  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3500  if (!innerSliceOp)
3501  return failure();
3502  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3503  if (!outerPadOp || outerPadOp.getNofold())
3504  return failure();
3505  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3506  if (!outerSliceOp)
3507  return failure();
3508 
3509  // 1) Fail if the chain is rank-reducing.
3510  int64_t rank = padOp.getSourceType().getRank();
3511  if (outerSliceOp.getSourceType().getRank() != rank) {
3512  return rewriter.notifyMatchFailure(padOp,
3513  "cannot fold rank-reducing chain");
3514  }
3515 
3516  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3517  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3518  return rewriter.notifyMatchFailure(
3519  padOp, "cannot fold non-unit stride ExtractSliceOps");
3520  }
3521 
3522  // 3) Fail if the tensor::PadOps have non-zero low padding.
3523  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3524  return rewriter.notifyMatchFailure(padOp,
3525  "cannot fold PadOps with low padding");
3526  }
3527 
3528  // 4) Fail if the tensor::PadOps padding values do not match.
3529  Attribute innerAttr, outerAttr;
3530  Value innerValue = padOp.getConstantPaddingValue();
3531  Value outerValue = outerPadOp.getConstantPaddingValue();
3532  if (!innerValue || !outerValue ||
3533  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3534  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3535  innerAttr != outerAttr) {
3536  return rewriter.notifyMatchFailure(
3537  padOp, "cannot fold PadOps with different padding values");
3538  }
3539 
3540  // 5) Fail if a dimension is padded by both tensor::PadOps.
3541  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3542  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3543  if (innerDims.anyCommon(outerDims)) {
3544  return rewriter.notifyMatchFailure(
3545  padOp, "cannot fold PadOps with common padding dimensions");
3546  }
3547 
3548  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3549  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
 3550  // for every dimension, and use the offset of the other pair. Fail if no
3551  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3552  // exists.
3553  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3554  for (auto en : enumerate(newOffsets)) {
3555  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3556  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3557  if (!innerDims.test(en.index()) &&
3558  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3559  en.value() = outerOffset;
3560  continue;
3561  }
3562  if (!outerDims.test(en.index()) &&
3563  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3564  en.value() = innerOffset;
3565  continue;
3566  }
3567  return rewriter.notifyMatchFailure(
3568  padOp, "cannot find zero-offset and zero-padding pair");
3569  }
3570 
3571  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3572  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3573  // outer tensor::PadOp and fail if the size of the inner
3574  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3575  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3576  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3577  for (auto en : enumerate(newSizes)) {
3578  if (!outerDims.test(en.index()))
3579  continue;
3580  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3581  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3582  assert(!ShapedType::isDynamic(sourceSize) &&
3583  "expected padded dimension to have a static size");
3584  if (getConstantIntValue(sliceSize) != sourceSize) {
3585  return rewriter.notifyMatchFailure(
3586  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3587  "match the size of the outer padding");
3588  }
3589  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3590  }
3591 
3592  // Combine the high paddings of the two tensor::PadOps.
3593  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3594  for (auto en : enumerate(newHighPad)) {
3595  if (innerDims.test(en.index()))
3596  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3597  if (outerDims.test(en.index()))
3598  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3599  }
3600 
3601  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3602  // the two paddings in one step.
3603  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3604  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3605  innerSliceOp.getMixedStrides());
3606  auto newPadOp = rewriter.create<PadOp>(
3607  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3608  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3609  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3610  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3611  newPadOp.getRegion().begin());
3612  rewriter.replaceOp(padOp, newPadOp.getResult());
3613  return success();
3614  }
3615 };
3616 
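/// Folds constant low/high padding operands of a tensor.pad into its static
/// attributes and infers a more static result type, casting back to the
/// original type. Sketch of the rewrite (illustrative IR):
///
/// ```mlir
/// %c2 = arith.constant 2 : index
/// %p = tensor.pad %t low[0, %c2] high[0, 1] { ...
///   } : tensor<4x8xf32> to tensor<?x?xf32>
/// ```
///
/// becomes a pad with static widths low[0, 2] high[0, 1] producing
/// tensor<4x11xf32>, wrapped in a `tensor.cast` back to tensor<?x?xf32>.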
3617 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
 3618  using OpRewritePattern<PadOp>::OpRewritePattern;
 3619 
3620  LogicalResult matchAndRewrite(PadOp padTensorOp,
3621  PatternRewriter &rewriter) const override {
3622  Value input = padTensorOp.getSource();
3623  if (!llvm::isa<RankedTensorType>(input.getType()))
3624  return failure();
3625  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3626  auto inputRank = inputDims.size();
3627 
3628  auto oldResultType =
3629  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3630  if (!oldResultType)
3631  return failure();
3632 
3633  auto outputDims = oldResultType.getShape();
3634 
3635  // Extract the static info from the high and low operands.
3636  SmallVector<int64_t> constOperandsLow;
3637  SmallVector<Value> newLows;
3638  for (auto operand : padTensorOp.getLow()) {
3639  APSInt intOp;
3640  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3641  constOperandsLow.push_back(ShapedType::kDynamic);
3642  newLows.push_back(operand);
3643  continue;
3644  }
3645  constOperandsLow.push_back(intOp.getExtValue());
3646  }
3647  SmallVector<int64_t> constOperandsHigh;
3648  SmallVector<Value> newHighs;
3649  for (auto operand : padTensorOp.getHigh()) {
3650  APSInt intOp;
3651  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3652  constOperandsHigh.push_back(ShapedType::kDynamic);
3653  newHighs.push_back(operand);
3654  continue;
3655  }
3656  constOperandsHigh.push_back(intOp.getExtValue());
3657  }
3658 
3659  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3660  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3661 
3662  // Verify the op is well-formed.
3663  if (inputDims.size() != outputDims.size() ||
3664  inputDims.size() != constLow.size() ||
3665  inputDims.size() != constHigh.size())
3666  return failure();
3667 
3668  auto lowCount = 0;
3669  auto highCount = 0;
3670  for (size_t i = 0; i < inputRank; i++) {
3671  if (constLow[i] == ShapedType::kDynamic)
3672  constLow[i] = constOperandsLow[lowCount++];
3673  if (constHigh[i] == ShapedType::kDynamic)
3674  constHigh[i] = constOperandsHigh[highCount++];
3675  }
3676 
3677  auto staticLow = ArrayRef<int64_t>(constLow);
3678  auto staticHigh = ArrayRef<int64_t>(constHigh);
3679 
3680  // Calculate the output sizes with the static information.
3681  SmallVector<int64_t> newOutDims;
3682  for (size_t i = 0; i < inputRank; i++) {
3683  if (outputDims[i] == ShapedType::kDynamic) {
3684  newOutDims.push_back(
3685  (staticLow[i] == ShapedType::kDynamic ||
3686  staticHigh[i] == ShapedType::kDynamic ||
3687  inputDims[i] == ShapedType::kDynamic
3688  ? ShapedType::kDynamic
3689  : inputDims[i] + staticLow[i] + staticHigh[i]));
3690  } else {
3691  newOutDims.push_back(outputDims[i]);
3692  }
3693  }
3694 
3695  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3696  llvm::all_of(newOutDims,
3697  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3698  return failure();
3699 
3700  // Rewrite the op using the new static type.
3701  auto newResultType = RankedTensorType::get(
3702  newOutDims, padTensorOp.getType().getElementType());
3703  auto newOp = rewriter.create<PadOp>(
3704  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3705  newLows, newHighs, padTensorOp.getNofold(),
3706  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3707 
3708  IRMapping mapper;
3709  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3710  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3711  newOp);
3712 
3713  return success();
3714  }
3715 };
3716 
3717 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3718 ///
3719 /// Example:
3720 ///
3721 /// ```mlir
3722 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3723 /// tensor.yield %val
 3724 /// } : tensor<1x2xf32> to tensor<1x5xf32>
3725 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3726 /// tensor.yield %val
 3727 /// } : tensor<1x5xf32> to tensor<4x7xf32>
3728 /// ```
3729 ///
3730 /// folds into:
3731 ///
3732 /// ```mlir
3733 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3734 /// tensor.yield %val
 3735 /// } : tensor<1x2xf32> to tensor<4x7xf32>
3736 /// ```
3737 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
 3738  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
 3739 
3740  LogicalResult matchAndRewrite(tensor::PadOp padOp,
3741  PatternRewriter &rewriter) const override {
3742  if (padOp.getNofold()) {
3743  return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3744  }
3745 
3746  auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3747  if (!producerPad || producerPad.getNofold()) {
3748  return rewriter.notifyMatchFailure(
3749  padOp, "producer is not a foldable tensor.pad op");
3750  }
3751 
3752  // Fail if the tensor::PadOps padding values do not match.
3753  Value consumerPadValue = padOp.getConstantPaddingValue();
3754  Value producerPadValue = producerPad.getConstantPaddingValue();
3755  if (!consumerPadValue || !producerPadValue ||
3756  consumerPadValue != producerPadValue) {
3757  return rewriter.notifyMatchFailure(
3758  padOp,
3759  "cannot fold PadOps with different or non-constant padding values");
3760  }
3761 
3762  Location loc = padOp.getLoc();
3763  AffineExpr d0, d1;
3764  bindDims(rewriter.getContext(), d0, d1);
3765 
3766  // Combine the low/high paddings of the two tensor::PadOps.
3767  auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3768  ArrayRef<OpFoldResult> producerPaddings) {
3769  SmallVector<OpFoldResult> sumPaddings;
3770  for (auto [consumerIndex, producerIndex] :
3771  llvm::zip_equal(consumerPaddings, producerPaddings)) {
3772  sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3773  rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3774  }
3775  return sumPaddings;
3776  };
3777 
3778  SmallVector<OpFoldResult> newHighPad =
3779  addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3780  SmallVector<OpFoldResult> newLowPad =
3781  addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3782 
3783  auto newPadOp = rewriter.create<tensor::PadOp>(
3784  padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3785  newLowPad, newHighPad, padOp.getNofold(),
3786  getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3787  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3788  newPadOp.getRegion().begin());
3789  rewriter.replaceOp(padOp, newPadOp.getResult());
3790  return success();
3791  }
3792 };
3793 
3794 } // namespace
3795 
3796 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3797  MLIRContext *context) {
3798  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3799  FoldOrthogonalPaddings, FoldStaticPadding,
3800  FoldConsecutiveConstantPadding>(context);
3801 }
3802 
 3803 /// Return the padding value of the PadOp if it is constant. In this context,
3804 /// "constant" means an actual constant or "defined outside of the block".
3805 ///
3806 /// Values are considered constant in three cases:
3807 /// - A ConstantLike value.
3808 /// - A basic block argument from a different block.
3809 /// - A value defined outside of the block.
3810 ///
3811 /// If the padding value is not constant, an empty Value is returned.
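///
/// For instance (illustrative IR), the yielded %cst below is returned as the
/// constant padding value:
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %t low[1, 1] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<2x2xf32> to tensor<4x4xf32>
/// ```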
3812 Value PadOp::getConstantPaddingValue() {
3813  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3814  if (!yieldOp)
3815  return {};
3816  Value padValue = yieldOp.getValue();
3817  // Check if yield value is a constant.
3818  if (matchPattern(padValue, m_Constant()))
3819  return padValue;
3820  // Check if yield value is defined inside the PadOp block.
3821  if (padValue.getParentBlock() == &getRegion().front())
3822  return {};
3823  // Else: Yield value defined outside of the PadOp block.
3824  return padValue;
3825 }
3826 
3827 OpFoldResult PadOp::fold(FoldAdaptor) {
3828  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3829  !getNofold())
3830  return getSource();
3831  return {};
3832 }
3833 
3834 //===----------------------------------------------------------------------===//
3835 // ParallelInsertSliceOp
3836 //===----------------------------------------------------------------------===//
3837 
3838 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3839  ParallelCombiningOpInterface parallelCombiningParent =
3840  getParallelCombiningParent();
3841  for (const auto &it :
3842  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3843  Operation &nextOp = it.value();
3844  if (&nextOp == getOperation())
3845  return parallelCombiningParent.getParentResult(it.index());
3846  }
3847  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3848 }
3849 
3850 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3851 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3852  Value source, Value dest,
3853  ArrayRef<OpFoldResult> offsets,
3854  ArrayRef<OpFoldResult> sizes,
3855  ArrayRef<OpFoldResult> strides,
3856  ArrayRef<NamedAttribute> attrs) {
3857  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3858  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3859  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3860  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3861  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3862  result.addAttributes(attrs);
3863  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3864  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3865  b.getDenseI64ArrayAttr(staticSizes),
3866  b.getDenseI64ArrayAttr(staticStrides));
3867 }
3868 
 3869 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3870 /// packed into a Range vector.
3871 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3872  Value source, Value dest,
3873  ArrayRef<Range> ranges,
3874  ArrayRef<NamedAttribute> attrs) {
3875  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3876  build(b, result, source, dest, offsets, sizes, strides, attrs);
3877 }
3878 
3879 // Build a ParallelInsertSliceOp with dynamic entries.
3880 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3881  Value source, Value dest, ValueRange offsets,
3882  ValueRange sizes, ValueRange strides,
3883  ArrayRef<NamedAttribute> attrs) {
3884  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3885  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3886  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3887  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3888  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3889  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3890  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3891 }
3892 
3893 LogicalResult ParallelInsertSliceOp::verify() {
3894  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3895  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3896  << *(getOperation()->getParentOp());
3897 
3898  // Verify result type against inferred type.
3899  RankedTensorType expectedType;
3900  SliceVerificationResult result =
3901  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3902  getStaticSizes(), getStaticStrides(), &expectedType);
3903  if (result != SliceVerificationResult::Success)
3904  return produceSliceErrorMsg(result, *this, expectedType);
3905 
3906  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3907  // to the destination tensor.
 3908  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
 3909  getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3910  getStaticStrides(), /*generateErrorMessage=*/true);
3911  if (!boundsResult.isValid)
3912  return getOperation()->emitError(boundsResult.errorMessage);
3913 
3914  return success();
3915 }
3916 
3917 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3918  RewritePatternSet &results, MLIRContext *context) {
3919  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3920  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3921  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3922 }
3923 
3924 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3925  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3926 }
3927 
3928 //===----------------------------------------------------------------------===//
3929 // ScatterOp
3930 //===----------------------------------------------------------------------===//
3931 
3932 void ScatterOp::getAsmResultNames(
3933  function_ref<void(Value, StringRef)> setNameFn) {
3934  setNameFn(getResult(), "scatter");
3935 }
3936 
3937 LogicalResult ScatterOp::verify() {
3938  int64_t destRank = getDestType().getRank();
3939  ArrayRef<int64_t> scatterDims = getScatterDims();
3940  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3941  getIndicesType().getShape(), destRank,
3942  "scatter", "dest")))
3943  return failure();
3944 
3945  if (!getUnique())
3946  return emitOpError("requires 'unique' attribute to be set");
3947  // TODO: we could also check statically that there are fewer leading index
3948  // tensor dims than the dest dims. If this is not the case, the unique
3949  // attribute cannot be true.
3950 
3951  // Use the GatherOp::inferResultType on the `dest` type and verify the
3952  // expected type matches the source type.
3953  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3954  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3955  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3956  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3957  if (getSourceType() != expectedSourceType &&
3958  getSourceType() != expectedRankReducedSourceType) {
 3959  return emitOpError("source type mismatch: expected ")
3962  << expectedSourceType << " or its rank-reduced variant "
3963  << expectedRankReducedSourceType << " (got: " << getSourceType()
3964  << ")";
3965  }
3966 
3967  return success();
3968 }
3969 
3970 //===----------------------------------------------------------------------===//
3971 // SplatOp
3972 //===----------------------------------------------------------------------===//
3973 
3974 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3975  Type aggregateType, ValueRange dynamicSizes) {
3976  build(builder, result, aggregateType, element, dynamicSizes);
3977 }
3978 
3979 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3980  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3981  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3982  build(builder, result, aggregateType, element, dynamicSizes);
3983 }
3984 
3985 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3986  ArrayRef<OpFoldResult> sizes) {
3987  SmallVector<int64_t> staticShape;
3988  SmallVector<Value> dynamicSizes;
3989  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3990  build(builder, result, element, staticShape, dynamicSizes);
3991 }
3992 
3993 void SplatOp::getAsmResultNames(
3994  function_ref<void(Value, StringRef)> setNameFn) {
3995  setNameFn(getResult(), "splat");
3996 }
3997 
3998 LogicalResult SplatOp::verify() {
3999  if (getType().getNumDynamicDims() != getDynamicSizes().size())
4000  return emitOpError("incorrect number of dynamic sizes, has ")
4001  << getDynamicSizes().size() << ", expected "
4002  << getType().getNumDynamicDims();
4003  return success();
4004 }
4005 
4006 LogicalResult
 4007 SplatOp::reifyResultShapes(OpBuilder &builder,
 4008  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4009  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4010  unsigned ctr = 0;
4011  for (int64_t i = 0; i < getType().getRank(); ++i) {
4012  if (getType().isDynamicDim(i)) {
4013  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4014  } else {
4015  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4016  }
4017  }
4018  return success();
4019 }
4020 
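/// Folding sketch (illustrative IR): `%s = tensor.splat %cst : tensor<4xf32>`
/// with %cst a constant 1.0 folds to the attribute
/// `dense<1.000000e+00> : tensor<4xf32>`.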
4021 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4022  auto constOperand = adaptor.getInput();
4023  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4024  return {};
4025 
4026  // Do not fold if the splat is not statically shaped
4027  if (!getType().hasStaticShape())
4028  return {};
4029 
 4030  // SplatElementsAttr::get treats a single value for the second arg as a
 4031  // splat.
4032  return SplatElementsAttr::get(getType(), {constOperand});
4033 }
4034 
4035 //===----------------------------------------------------------------------===//
4036 // Common Canonicalizers and Folders.
4037 //===----------------------------------------------------------------------===//
4038 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4039  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4040  // 2. Exclude DPS ops that are also LoopLike from this interface as they
4041  // might need special handling of attached regions.
4042  if (isa<InsertSliceOp>(op.getOperation()) ||
4043  isa<LoopLikeOpInterface>(op.getOperation()))
4044  return false;
4045 
4046  return hasFoldableTensorCastOperand(op);
4047 }
4048 
4049 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4050 /// the `tensor.cast` has source that is more static than the consuming op.
4051 ///
4052 /// Example:
4053 /// ```mlir
4054 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4055 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4056 /// ```
4057 ///
4058 /// folds into:
4059 ///
4060 /// ```mlir
4061 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4062 /// ```
4063 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4064 /// can add the pattern to their canonicalizers.
 4065 struct FoldTensorCastProducerOp
 4066  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
 4067  using OpInterfaceRewritePattern<
 4068  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4069 
4070  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4071  PatternRewriter &rewriter) const override {
4072 
4073  // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4074  // for that instead.
4075  if (!foldTensorCastPrecondition(op) ||
4076  isa<linalg::RelayoutOpInterface>(*op))
4077  return failure();
4078 
4079  SmallVector<Type> newResultTypes(op->getResultTypes());
4080  SmallVector<Value> newOperands =
4081  getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4082 
4083  // Clone op
4084  auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4085 
4086  SmallVector<Value, 4> replacements;
4087  replacements.reserve(newOp->getNumResults());
4088  for (auto [oldResult, newResult] :
4089  llvm::zip(op->getResults(), newOp->getResults())) {
4090  if (newResult.getType() != oldResult.getType()) {
4091  replacements.push_back(rewriter.create<tensor::CastOp>(
4092  op->getLoc(), oldResult.getType(), newResult));
4093  } else {
4094  replacements.push_back(newResult);
4095  }
4096  }
4097  rewriter.replaceOp(op, replacements);
4098 
4099  return success();
4100  }
4101 };
4102 
4103 //===----------------------------------------------------------------------===//
4104 // TensorDialect
4105 //===----------------------------------------------------------------------===//
4106 
4107 void TensorDialect::getCanonicalizationPatterns(
4108  RewritePatternSet &results) const {
 4109  results.add<FoldTensorCastProducerOp>(getContext());
 4110 }
4111 
4112 //===----------------------------------------------------------------------===//
4113 // TableGen'd op method definitions
4114 //===----------------------------------------------------------------------===//
4115 
4116 #define GET_OP_CLASSES
4117 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2691
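A hedged sketch of wiring this entry point up with a custom control callback; the helper name and the 64-element threshold are invented for illustration and are not part of TensorOps.cpp.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void addConstantExtractSliceFolding(RewritePatternSet &patterns) {
  // Hypothetical policy: only fold slices of small constants so that folding
  // does not duplicate large constant data.
  tensor::populateFoldConstantExtractSlicePatterns(
      patterns, [](tensor::ExtractSliceOp op) {
        auto srcType = op.getSourceType();
        return srcType.hasStaticShape() && srcType.getNumElements() <= 64;
      });
}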
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:356
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:372
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:325
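The cast-folding helpers above are typically used together; the fragment below is an illustrative sketch rather than code from this file (`op`, `consumerOperand`, and the surrounding hooks are assumed, with `using namespace mlir;` in effect).
// In a hypothetical rewrite: look through a producing tensor.cast when the
// consumer can use the original, more static source directly.
Value source = consumerOperand;
if (auto castOp = source.getDefiningOp<tensor::CastOp>())
  if (tensor::canFoldIntoConsumerOp(castOp))
    source = castOp.getSource();

// In a hypothetical fold hook: absorb producing casts into `op` in place.
if (tensor::hasFoldableTensorCastOperand(op) &&
    succeeded(tensor::foldTensorCast(op)))
  return op->getResult(0);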
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Definition: TensorOps.cpp:3182
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
Definition: TensorOps.cpp:2778
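A minimal sketch of the two rank-reducing helpers, assuming a builder `b`, a location `loc`, a source `src` of type tensor<1x1x16xf32>, and a destination `dest` of the same type; none of these names come from TensorOps.cpp.
// Extract a rank-reduced tensor<16xf32> at offsets [0, 0, 0] with unit
// strides, then insert it back into `dest` at the same position.
auto elemTy = cast<RankedTensorType>(src.getType()).getElementType();
auto reducedTy = RankedTensorType::get({16}, elemTy);
Value slice =
    tensor::createCanonicalRankReducingExtractSliceOp(b, loc, src, reducedTy);
Value updated =
    tensor::createCanonicalRankReducingInsertSliceOp(b, loc, slice, dest);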
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:131
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:64
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
Definition: TensorOps.cpp:1445
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:82
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:167
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:273
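Two assertions that follow from the description above, built with an assumed builder `b` (illustrative only):
auto f32 = b.getF32Type();
auto staticTy = RankedTensorType::get({4, 8}, f32);
auto partialTy = RankedTensorType::get({ShapedType::kDynamic, 8}, f32);
// Casting a partially dynamic tensor to a fully static one keeps (and even
// refines) the static information.
assert(tensor::preservesStaticInformation(/*source=*/partialTy,
                                          /*target=*/staticTy));
// The reverse direction drops the static extent of dimension 0.
assert(!tensor::preservesStaticInformation(/*source=*/staticTy,
                                           /*target=*/partialTy));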
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
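A short usage sketch, assuming an OpBuilder `b`, a Location `loc`, and a ranked tensor value `v` (names are illustrative):
// Per-dimension sizes: static extents come back as index attributes, dynamic
// extents as (possibly folded) tensor.dim values.
SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, v);
// A single dimension can be queried the same way.
OpFoldResult size0 = tensor::getMixedSize(b, loc, v, /*dim=*/0);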
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:117
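Both destination helpers are usually called from transforms that need DPS init operands before rewriting an op; a hedged sketch with `b` and `op` assumed to be in scope:
// All results at once.
SmallVector<Value> destinations;
if (failed(tensor::getOrCreateDestinations(b, op->getLoc(), op, destinations)))
  return failure();

// Or a single result: either the tied init operand of a destination-style op,
// or a freshly created tensor.empty of the right shape.
FailureOr<Value> dest =
    tensor::getOrCreateDestination(b, op->getLoc(), op->getResult(0));
if (failed(dest))
  return failure();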
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:340
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
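These utilities round-trip between the unified OpFoldResult form and the split (static array, dynamic values) encoding used by view-like ops; a sketch with assumed inputs `staticSizes`, `dynamicSizes`, and a builder `b`:
// Merge: entries of `staticSizes` equal to ShapedType::kDynamic are taken
// from `dynamicSizes`; all others become index attributes.
SmallVector<OpFoldResult> mixed =
    getMixedValues(staticSizes, dynamicSizes, b.getContext());

// Split again, either into caller-provided vectors ...
SmallVector<int64_t> staticVec;
SmallVector<Value> dynamicVec;
dispatchIndexOpFoldResults(mixed, dynamicVec, staticVec);

// ... or as a returned pair.
auto [staticVals, dynamicVals] = decomposeMixedValues(mixed);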
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:91
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
Definition: TensorOps.cpp:4066
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:4070
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2710
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2711
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2698
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2699
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.