MLIR 19.0.0git
TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include <algorithm>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::tensor;
37 
38 /// Materialize a single constant operation from a given attribute value with
39 /// the desired resultant type.
40 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
41  Attribute value, Type type,
42  Location loc) {
43  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
44  return op;
45  if (complex::ConstantOp::isBuildableWith(value, type))
46  return builder.create<complex::ConstantOp>(loc, type,
47  llvm::cast<ArrayAttr>(value));
48  return nullptr;
49 }
50 
51 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
52  int64_t dim) {
53  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
55  if (tensorType.isDynamicDim(dim))
56  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
57 
58  return builder.getIndexAttr(tensorType.getDimSize(dim));
59 }
60 
61 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
62  Location loc, Value value) {
63  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
64  SmallVector<OpFoldResult> result;
65  for (int64_t i = 0; i < tensorType.getRank(); ++i)
66  result.push_back(getMixedSize(builder, loc, value, i));
67  return result;
68 }
69 
70 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
71  OpResult opResult) {
72  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
73  assert(tensorType && "expected tensor type");
74 
75  // If the op has a destination, it implements DestinationStyleOpInterface and
76  // we can query the destination operand from that interface.
77  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
78  if (destOp)
79  return destOp.getTiedOpOperand(opResult)->get();
80 
81  // Otherwise, create a new destination tensor with the same shape.
82  OpBuilder::InsertionGuard g(b);
83  b.setInsertionPoint(opResult.getDefiningOp());
84 
85  // Compute sizes.
86  SmallVector<OpFoldResult> mixedSizes;
87  if (!tensorType.hasStaticShape()) {
88  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
89  ReifiedRankedShapedTypeDims reifiedShapes;
90  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
91  return failure();
92  mixedSizes = reifiedShapes[opResult.getResultNumber()];
93  } else {
94  // Static shape: Take static sizes directly.
95  for (int64_t sz : tensorType.getShape())
96  mixedSizes.push_back(b.getIndexAttr(sz));
97  }
98 
99  // Create empty tensor.
100  Value emptyTensor =
101  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
102  return emptyTensor;
103 }
104 
105 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
106  Operation *op,
107  SmallVector<Value> &result) {
108  for (OpResult opResult : op->getResults()) {
109  if (llvm::isa<TensorType>(opResult.getType())) {
110  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
111  if (failed(destination))
112  return failure();
113  result.push_back(*destination);
114  }
115  }
116  return success();
117 }
118 
119 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
120  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
121  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
122  return rtp1.getShape() == rtp2.getShape() &&
123  rtp1.getElementType() == rtp2.getElementType();
124  return false;
125  }
126  return tp1 == tp2; // default implementation
127 }
128 
129 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
130 /// rank-extending tensor.insert_slice op.
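/// For example (shapes chosen purely for illustration), the rank-reducing slice
/// ```mlir
/// %0 = tensor.extract_slice %t[0, 0, 0] [1, 16, 1] [1, 1, 1]
///     : tensor<8x16x4xf32> to tensor<16xf32>
/// ```
/// has mixed sizes [1, 16, 1] and reduced shape [16]; the returned bit vector
/// marks dimensions 0 and 2 as dropped.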
131 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
132  ArrayRef<OpFoldResult> mixedSizes) {
133  llvm::SmallBitVector droppedDims(mixedSizes.size());
134  int64_t shapePos = 0;
135 
136  for (const auto &size : enumerate(mixedSizes)) {
137  // Rank-reduced dims must have a static unit dimension.
138  bool isStaticUnitSize =
139  size.value().is<Attribute>() &&
140  llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
141 
142  if (shapePos == static_cast<int64_t>(reducedShape.size())) {
143  // There are no more dims in the reduced shape. All remaining sizes must
144  // be rank-reduced dims.
145  assert(isStaticUnitSize && "expected unit dim");
146  droppedDims.set(size.index());
147  continue;
148  }
149 
150  // Dim is preserved if the size is not a static 1.
151  if (!isStaticUnitSize) {
152  ++shapePos;
153  continue;
154  }
155 
156  // Dim is preserved if the reduced shape dim is also 1.
157  if (reducedShape[shapePos] == 1) {
158  ++shapePos;
159  continue;
160  }
161 
162  // Otherwise: Dim is dropped.
163  droppedDims.set(size.index());
164  }
165 
166  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
167  "dimension mismatch");
168  return droppedDims;
169 }
170 
171 /// Given a ranked tensor type and a range of values that defines its dynamic
172 /// dimension sizes, turn all dynamic sizes that have a constant value into
173 /// static dimension sizes.
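/// For example (illustrative values only): for a type tensor<?x?xf32> with
/// dynamic sizes (%d0, %c5), where %c5 is produced by `arith.constant 5 : index`,
/// the returned type is tensor<?x5xf32> and `foldedDynamicSizes` contains only
/// %d0.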
174 static RankedTensorType
175 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
176  SmallVector<Value> &foldedDynamicSizes) {
177  SmallVector<int64_t> staticShape(type.getShape().begin(),
178  type.getShape().end());
179  assert(type.getNumDynamicDims() ==
180  static_cast<int64_t>(dynamicSizes.size()) &&
181  "incorrect number of dynamic sizes");
182 
183  // Compute new static and dynamic sizes.
184  unsigned ctr = 0;
185  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
186  if (type.isDynamicDim(i)) {
187  Value dynamicSize = dynamicSizes[ctr++];
188  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
189  if (cst.has_value()) {
190  // Dynamic size must be non-negative.
191  if (cst.value() < 0) {
192  foldedDynamicSizes.push_back(dynamicSize);
193  continue;
194  }
195  staticShape[i] = *cst;
196  } else {
197  foldedDynamicSizes.push_back(dynamicSize);
198  }
199  }
200  }
201 
202  return RankedTensorType::get(staticShape, type.getElementType(),
203  type.getEncoding());
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // BitcastOp
208 //===----------------------------------------------------------------------===//
209 
210 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
211  if (inputs.size() != 1 || outputs.size() != 1)
212  return false;
213  Type a = inputs.front(), b = outputs.front();
214  auto aT = dyn_cast<TensorType>(a);
215  auto bT = dyn_cast<TensorType>(b);
216  if (!aT || !bT)
217  return false;
218 
219  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
220  return false;
221 
222  return succeeded(verifyCompatibleShape(aT, bT));
223 }
224 
225 namespace {
226 
227 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
228 /// operation.
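/// For example (types chosen for illustration):
/// ```mlir
/// %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
/// %2 = tensor.bitcast %1 : tensor<4xi32> to tensor<4xsi32>
/// ```
/// is rewritten to
/// ```mlir
/// %2 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xsi32>
/// ```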
229 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
230  using OpRewritePattern<BitcastOp>::OpRewritePattern;
231
232  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
233  PatternRewriter &rewriter) const final {
234  auto tensorBitcastOperand =
235  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
236  if (!tensorBitcastOperand)
237  return failure();
238 
239  auto resultType = cast<TensorType>(tensorBitcast.getType());
240  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
241  tensorBitcastOperand.getOperand());
242  return success();
243  }
244 };
245 
246 } // namespace
247 
248 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<ChainedTensorBitcast>(context);
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // CastOp
255 //===----------------------------------------------------------------------===//
256 
257 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
258  setNameFn(getResult(), "cast");
259 }
260 
261 /// Returns true if `target` is a ranked tensor type that preserves static
262 /// information available in the `source` ranked tensor type.
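/// For example (shapes chosen for illustration): a target type of
/// tensor<8x16xf32> preserves the static information of a source type of
/// tensor<8x?xf32>, whereas a target type of tensor<8x?xf32> does not preserve
/// the static size of dimension 1 of a source type of tensor<8x16xf32>.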
263 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
264  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
265  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
266 
267  // Requires RankedTensorType.
268  if (!sourceType || !targetType)
269  return false;
270 
271  // Requires same elemental type.
272  if (sourceType.getElementType() != targetType.getElementType())
273  return false;
274 
275  // Requires same rank.
276  if (sourceType.getRank() != targetType.getRank())
277  return false;
278 
279  // If cast is towards more static sizes along any dimension, don't fold.
280  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
281  if (!ShapedType::isDynamic(std::get<0>(t)) &&
282  ShapedType::isDynamic(std::get<1>(t)))
283  return false;
284  }
285 
286  return true;
287 }
288 
289 /// Determines whether tensor::CastOp casts to a more dynamic version of the
290 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
291 /// implement canonicalization patterns for ops in different dialects that may
292 /// consume the results of tensor.cast operations. Such foldable tensor.cast
293 /// operations are typically inserted as `slice` ops and are canonicalized
294 /// to preserve the type compatibility of their uses.
295 ///
296 /// Returns true when all conditions are met:
297 /// 1. source and result are ranked tensors with same element type and rank.
298 /// 2. the source type has more static information than the result type.
299 ///
300 /// Example:
301 /// ```mlir
302 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
303 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
304 /// ```
305 ///
306 /// folds into:
307 ///
308 /// ```mlir
309 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
310 /// ```
311 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
312  if (!castOp)
313  return false;
314 
315  // Can fold if the source of cast has at least as much static information as
316  // its results.
317  return preservesStaticInformation(castOp.getType(),
318  castOp.getSource().getType());
319 }
320 
321 /// Determines whether the tensor::CastOp casts to a more static version of the
322 /// source tensor. This is useful to fold into a producing op and implement
323 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
324 /// being from different dialects. Returns true when all conditions are met:
325 /// 1. source and result are ranked tensors with same element type and rank.
326 /// 2. the result type has more static information than the source.
327 ///
328 /// Example:
329 /// ```mlir
330 /// %1 = producer ... : tensor<?x?xf32>
331 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
332 /// ```
333 ///
334 /// can be canonicalized to:
335 ///
336 /// ```mlir
337 /// %2 = producer ... : tensor<8x16xf32>
338 /// ```
339 /// Not all ops might be canonicalizable this way, but for those that can be,
340 /// this method provides a check that it is worth doing the canonicalization.
341 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
342  if (!castOp)
343  return false;
344  return preservesStaticInformation(castOp.getSource().getType(),
345  castOp.getType());
346 }
347 
348 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
349 /// that can be folded.
350 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
351  bool folded = false;
352  for (OpOperand &operand : op->getOpOperands()) {
353  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
354  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
355  operand.set(castOp.getOperand());
356  folded = true;
357  }
358  }
359  return success(folded);
360 }
361 
362 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
363  if (inputs.size() != 1 || outputs.size() != 1)
364  return false;
365  Type a = inputs.front(), b = outputs.front();
366  auto aT = llvm::dyn_cast<TensorType>(a);
367  auto bT = llvm::dyn_cast<TensorType>(b);
368  if (!aT || !bT)
369  return false;
370 
371  if (aT.getElementType() != bT.getElementType())
372  return false;
373 
374  return succeeded(verifyCompatibleShape(aT, bT));
375 }
376 
377 /// Compute a TensorType that has the joined shape knowledge of the two
378 /// given TensorTypes. The element types need to match.
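/// For example (illustrative types): joining tensor<?x16xf32> with
/// tensor<8x?xf32> yields tensor<8x16xf32>, while joining tensor<8xf32> with
/// tensor<16xf32> yields a null type because the static sizes conflict.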
379 static TensorType joinShapes(TensorType one, TensorType two) {
380  assert(one.getElementType() == two.getElementType());
381 
382  if (!one.hasRank())
383  return two;
384  if (!two.hasRank())
385  return one;
386 
387  int64_t rank = one.getRank();
388  if (rank != two.getRank())
389  return {};
390 
391  SmallVector<int64_t, 4> join;
392  join.reserve(rank);
393  for (int64_t i = 0; i < rank; ++i) {
394  if (one.isDynamicDim(i)) {
395  join.push_back(two.getDimSize(i));
396  continue;
397  }
398  if (two.isDynamicDim(i)) {
399  join.push_back(one.getDimSize(i));
400  continue;
401  }
402  if (one.getDimSize(i) != two.getDimSize(i))
403  return {};
404  join.push_back(one.getDimSize(i));
405  }
406  return RankedTensorType::get(join, one.getElementType());
407 }
408 
409 namespace {
410 
411 /// Replaces chains of two tensor.cast operations by a single tensor.cast
412 /// operation if doing so does not remove runtime constraints.
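/// For example (types chosen for illustration):
/// ```mlir
/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// %2 = tensor.cast %1 : tensor<4x?xf32> to tensor<4x8xf32>
/// ```
/// is rewritten to
/// ```mlir
/// %2 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```
/// because dropping the intermediate cast loses no runtime check in this case.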
413 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
414  using OpRewritePattern<CastOp>::OpRewritePattern;
415
416  LogicalResult matchAndRewrite(CastOp tensorCast,
417  PatternRewriter &rewriter) const final {
418  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
419 
420  if (!tensorCastOperand)
421  return failure();
422 
423  auto sourceType =
424  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
425  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
426  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
427 
428  // We can remove the intermediate cast if joining all three produces the
429  // same result as just joining the source and result shapes.
430  auto firstJoin =
431  joinShapes(joinShapes(sourceType, intermediateType), resultType);
432 
433  // The join might not exist if the cast sequence would fail at runtime.
434  if (!firstJoin)
435  return failure();
436 
437  // The newJoin always exists if the above join exists, but it might contain
438  // less information. If so, we cannot drop the intermediate cast, as doing
439  // so would remove runtime checks.
440  auto newJoin = joinShapes(sourceType, resultType);
441  if (firstJoin != newJoin)
442  return failure();
443 
444  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
445  tensorCastOperand.getOperand());
446  return success();
447  }
448 };
449 
450 /// Fold tensor.cast into a tensor.extract_slice producer.
451 /// Example:
452 /// ```
453 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
454 /// tensor<128x512xf32> to tensor<?x512xf32>
455 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
456 /// ```
457 /// ->
458 /// ```
459 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
460 /// tensor<128x512xf32> to tensor<16x512xf32>
461 /// ```
462 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
463  using OpRewritePattern<CastOp>::OpRewritePattern;
464
465  LogicalResult matchAndRewrite(CastOp tensorCast,
466  PatternRewriter &rewriter) const final {
467  auto extractOperand =
468  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
469 
470  // Cannot fold cast to unranked tensor.
471  auto rankedResultType =
472  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
473  if (!rankedResultType)
474  return failure();
475 
476  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
477  rankedResultType.getShape() ==
478  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
479  .getShape())
480  return failure();
481 
482  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
483  auto dimMask = computeRankReductionMask(
484  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
485  size_t dimIndex = 0;
486  for (size_t i = 0, e = sizes.size(); i < e; i++) {
487  if (dimMask && dimMask->count(i))
488  continue;
489  int64_t dim = rankedResultType.getShape()[dimIndex++];
490  if (ShapedType::isDynamic(dim))
491  continue;
492  sizes[i] = rewriter.getIndexAttr(dim);
493  }
494 
495  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
496  tensorCast, rankedResultType, extractOperand.getSource(),
497  extractOperand.getMixedOffsets(), sizes,
498  extractOperand.getMixedStrides());
499  return success();
500  }
501 };
502 
503 } // namespace
504 
505 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
506  MLIRContext *context) {
507  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // ConcatOp
512 //===----------------------------------------------------------------------===//
513 
514 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
515  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
516  auto tensorTypes =
517  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
518  return llvm::cast<RankedTensorType>(type);
519  }));
520  int64_t concatRank = tensorTypes[0].getRank();
521 
522  // The concatenation dim must be in the range [0, rank).
523  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
524 
525  SmallVector<int64_t> sizes(concatRank);
526  for (int64_t i = 0, e = concatRank; i < e; ++i) {
527  if (i == dim)
528  continue;
529  SaturatedInteger size;
530  for (auto tensorType : tensorTypes)
531  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
532  sizes[i] = size.asInteger();
533  }
534  auto concatSize = SaturatedInteger::wrap(0);
535  for (auto tensorType : tensorTypes)
536  concatSize =
537  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
538  sizes[dim] = concatSize.asInteger();
539  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
540 }
541 
542 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
543  ValueRange inputs) {
544  FailureOr<RankedTensorType> resultType =
545  inferResultType(dim, inputs.getTypes());
546  assert(succeeded(resultType) && "failed to infer concatenation result type");
547  build(builder, result, *resultType, dim, inputs);
548 }
549 
550 LogicalResult ConcatOp::verify() {
551  if (getInputs().size() < 1)
552  return emitOpError("requires at least one input");
553 
554  SmallVector<RankedTensorType> inputTypes;
555  for (auto input : getInputs())
556  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
557 
558  RankedTensorType resultType = getResultType();
559  int64_t resultRank = getRank();
560  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
561  return type.getRank() != resultRank;
562  }))
563  return emitOpError("rank of concatenated inputs must match result rank");
564 
565  Type resultElementType = resultType.getElementType();
566  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
567  return type.getElementType() != resultElementType;
568  }))
569  return emitOpError("inputs and result element type must match");
570 
571  int64_t dim = getDim();
572  if (dim >= resultRank)
573  return emitOpError("concatenation dim must be less than the tensor rank");
574 
575  SmallVector<int64_t> sizes(resultRank);
576  for (int64_t i = 0, e = resultRank; i < e; ++i) {
577  if (i == dim)
578  continue;
579  SaturatedInteger size;
580  for (auto tensorType : inputTypes) {
581  FailureOr<SaturatedInteger> maybeSize =
582  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
583  if (failed(maybeSize))
584  return emitOpError("static concatenation size mismatch along ")
585  << "non-concatenated dimension " << i;
586  size = *maybeSize;
587  }
588  sizes[i] = size.asInteger();
589  }
590  auto concatSize = SaturatedInteger::wrap(0);
591  for (auto tensorType : inputTypes)
592  concatSize =
593  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
594  sizes[dim] = concatSize.asInteger();
595  auto inferredResultType =
596  RankedTensorType::get(sizes, inputTypes[0].getElementType());
597 
598  for (auto [inferredSize, actualSize] :
599  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
600  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
601  ShapedType::isDynamic(actualSize);
602  if (!hasDynamic && inferredSize != actualSize)
603  return emitOpError("result type ")
604  << resultType << " does not match inferred shape "
605  << inferredResultType << " static sizes";
606  }
607 
608  return success();
609 }
610 
611 LogicalResult
612 ConcatOp::reifyResultShapes(OpBuilder &builder,
613  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
614  ValueRange inputs = getInputs();
615  int64_t dim = getDim();
616  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
617 
618  Value init = inputs[0];
619  int64_t rank = getType().getRank();
620 
621  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
622 
623  // Pre-populate the result sizes with as much static information as possible
624  // from the given result type and the inferred result type; otherwise, fall
625  // back to the dim sizes of the first input.
626  for (int64_t i = 0; i < rank; ++i) {
627  if (i == dim)
628  continue;
629  if (!getType().isDynamicDim(i)) {
630  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
631  } else if (!inferredResultType.isDynamicDim(i)) {
632  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
633  builder, getLoc(),
634  builder.getIndexAttr(inferredResultType.getDimSize(i)));
635  } else {
636  reifiedReturnShapes[0][i] =
637  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
638  }
639  }
640 
641  if (getType().isDynamicDim(dim)) {
642  // Take the sum of the input sizes along the concatenated dim.
643  AffineExpr sum = builder.getAffineDimExpr(0);
644  SmallVector<OpFoldResult> sizes = {
645  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
646  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
647  sum = sum + builder.getAffineDimExpr(idx + 1);
648  sizes.push_back(
649  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
650  }
651  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
652  builder, getLoc(),
653  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
654  } else {
655  // If the result shape is static along the concatenated dim, use the static
656  // shape.
657  reifiedReturnShapes[0][dim] =
658  builder.getIndexAttr(getType().getDimSize(dim));
659  }
660  return success();
661 }
662 
663 void ConcatOp::getAsmResultNames(
664  function_ref<void(Value, StringRef)> setNameFn) {
665  setNameFn(getResult(), "concat");
666 }
667 
668 OpFoldResult ConcatOp::fold(FoldAdaptor) {
669  ValueRange inputs = getInputs();
670  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
671  return inputs[0];
672  return {};
673 }
674 
675 namespace {
676 /// Fold a concat op with a single input to a cast.
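/// For example (illustrative types):
/// ```mlir
/// %0 = tensor.concat dim(0) %arg0 : (tensor<4x8xf32>) -> tensor<?x8xf32>
/// ```
/// becomes
/// ```mlir
/// %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x8xf32>
/// ```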
677 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
678  using OpRewritePattern<ConcatOp>::OpRewritePattern;
679
680  LogicalResult matchAndRewrite(ConcatOp concatOp,
681  PatternRewriter &rewriter) const override {
682  if (concatOp.getInputs().size() != 1)
683  return failure();
684  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
685  concatOp.getInputs()[0]);
686  return success();
687  }
688 };
689 } // namespace
690 
691 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
692  MLIRContext *context) {
693  results.add<SingleInputConcatOp>(context);
694 }
695 
696 //===----------------------------------------------------------------------===//
697 // DimOp
698 //===----------------------------------------------------------------------===//
699 
700 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
701  setNameFn(getResult(), "dim");
702 }
703 
704 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
705  int64_t index) {
706  auto loc = result.location;
707  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
708  build(builder, result, source, indexValue);
709 }
710 
711 std::optional<int64_t> DimOp::getConstantIndex() {
712  return getConstantIntValue(getIndex());
713 }
714 
715 Speculation::Speculatability DimOp::getSpeculatability() {
716  auto constantIndex = getConstantIndex();
717  if (!constantIndex)
718  return Speculation::NotSpeculatable;
719
720  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
721  if (!rankedSourceType)
722  return Speculation::NotSpeculatable;
723
724  if (rankedSourceType.getRank() <= constantIndex)
725  return Speculation::NotSpeculatable;
726
727  return Speculation::Speculatable;
728 }
729 
730 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
731  // All forms of folding require a known index.
732  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
733  if (!index)
734  return {};
735 
736  // Folding for unranked types (UnrankedTensorType) is not supported.
737  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
738  if (!tensorType)
739  return {};
740 
741  // Out of bound indices produce undefined behavior but are still valid IR.
742  // Don't choke on them.
743  int64_t indexVal = index.getInt();
744  if (indexVal < 0 || indexVal >= tensorType.getRank())
745  return {};
746 
747  // Fold if the shape extent along the given index is known.
748  if (!tensorType.isDynamicDim(index.getInt())) {
749  Builder builder(getContext());
750  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
751  }
752 
753  Operation *definingOp = getSource().getDefiningOp();
754 
755  // Fold dim to the operand of tensor.generate.
756  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
757  auto resultType =
758  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
759  // The case where the type encodes the size of the dimension is handled
760  // above.
761  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
762 
763  // Find the operand of the fromElements that corresponds to this index.
764  auto dynExtents = fromElements.getDynamicExtents().begin();
765  for (auto dim : resultType.getShape().take_front(index.getInt()))
766  if (ShapedType::isDynamic(dim))
767  dynExtents++;
768 
769  return Value{*dynExtents};
770  }
771 
772  // The size at the given index is now known to be a dynamic size.
773  unsigned unsignedIndex = index.getValue().getZExtValue();
774 
775  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
776  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
777  // `resolve-shaped-type-result-dims` pass.
778  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
779  sliceOp.isDynamicSize(unsignedIndex)) {
780  return {sliceOp.getDynamicSize(unsignedIndex)};
781  }
782  }
783 
784  // dim(cast) -> dim
785  if (succeeded(foldTensorCast(*this)))
786  return getResult();
787 
788  return {};
789 }
790 
791 namespace {
792 /// Fold dim of a cast into the dim of the source of the tensor cast.
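/// For example (illustrative types, with %c1 an index constant):
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x?xf32> to tensor<?x?xf32>
/// %d = tensor.dim %1, %c1 : tensor<?x?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = tensor.dim %0, %c1 : tensor<8x?xf32>
/// ```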
793 struct DimOfCastOp : public OpRewritePattern<DimOp> {
794  using OpRewritePattern<DimOp>::OpRewritePattern;
795
796  LogicalResult matchAndRewrite(DimOp dimOp,
797  PatternRewriter &rewriter) const override {
798  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
799  if (!castOp)
800  return failure();
801  Value newSource = castOp.getOperand();
802  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
803  return success();
804  }
805 };
806 
807 /// Fold dim of a destination passing style op into the dim of the corresponding
808 /// init.
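/// For example (using linalg.copy purely as an illustrative destination-style
/// op, with %c0 an index constant): in
/// ```mlir
/// %r = linalg.copy ins(%a : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>)
///     -> tensor<?x?xf32>
/// %d = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
/// the tensor.dim is rewritten in place to take the init operand as its source:
/// ```mlir
/// %d = tensor.dim %init, %c0 : tensor<?x?xf32>
/// ```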
809 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
810  using OpRewritePattern<DimOp>::OpRewritePattern;
811
812  LogicalResult matchAndRewrite(DimOp dimOp,
813  PatternRewriter &rewriter) const override {
814  auto source = dimOp.getSource();
815  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
816  if (!destOp)
817  return failure();
818 
819  auto resultIndex = source.cast<OpResult>().getResultNumber();
820  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
821 
822  rewriter.modifyOpInPlace(
823  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
824  return success();
825  }
826 };
827 } // namespace
828 
829 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
830  MLIRContext *context) {
831  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // EmptyOp
836 //===----------------------------------------------------------------------===//
837 
838 void EmptyOp::build(OpBuilder &builder, OperationState &result,
839  ArrayRef<int64_t> staticShape, Type elementType,
840  Attribute encoding) {
841  assert(all_of(staticShape,
842  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
843  "expected only static sizes");
844  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
845 }
846 
847 void EmptyOp::build(OpBuilder &builder, OperationState &result,
848  ArrayRef<int64_t> staticShape, Type elementType,
849  ValueRange dynamicSizes, Attribute encoding) {
850  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
851  build(builder, result, tensorType, dynamicSizes);
852 }
853 
854 void EmptyOp::build(OpBuilder &builder, OperationState &result,
855  ArrayRef<OpFoldResult> sizes, Type elementType,
856  Attribute encoding) {
857  SmallVector<int64_t> staticShape;
858  SmallVector<Value> dynamicSizes;
859  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
860  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
861 }
862 
863 LogicalResult EmptyOp::verify() {
864  if (getType().getNumDynamicDims() !=
865  static_cast<int64_t>(getDynamicSizes().size()))
866  return emitOpError("incorrect number of dynamic sizes, has ")
867  << getDynamicSizes().size() << ", expected "
868  << getType().getNumDynamicDims();
869  return success();
870 }
871 
872 LogicalResult
873 EmptyOp::reifyResultShapes(OpBuilder &builder,
874  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
875  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
876  unsigned ctr = 0;
877  for (int64_t i = 0; i < getType().getRank(); ++i) {
878  if (getType().isDynamicDim(i)) {
879  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
880  } else {
881  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
882  }
883  }
884  return success();
885 }
886 
887 Value EmptyOp::getDynamicSize(unsigned idx) {
888  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
889  unsigned ctr = 0;
890  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
891  if (getType().isDynamicDim(i))
892  ++ctr;
893  return getDynamicSizes()[ctr];
894 }
895 
896 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
897  SmallVector<OpFoldResult> result;
898  unsigned ctr = 0;
899  OpBuilder b(getContext());
900  for (int64_t i = 0; i < getType().getRank(); ++i) {
901  if (getType().isDynamicDim(i)) {
902  result.push_back(getDynamicSizes()[ctr++]);
903  } else {
904  result.push_back(b.getIndexAttr(getType().getShape()[i]));
905  }
906  }
907  return result;
908 }
909 
910 namespace {
911 /// Change the type of the result of a `tensor.empty` by making the result
912 /// type statically sized along dimensions that in the original operation were
913 /// defined as dynamic, but the size was defined using a `constant` op. For
914 /// example
915 ///
916 /// %c5 = arith.constant 5: index
917 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
918 ///
919 /// to
920 ///
921 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
922 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
923  using OpRewritePattern<EmptyOp>::OpRewritePattern;
924
925  LogicalResult matchAndRewrite(EmptyOp op,
926  PatternRewriter &rewriter) const override {
927  SmallVector<Value> foldedDynamicSizes;
928  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
929  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
930 
931  // Stop here if no dynamic size was promoted to static.
932  if (foldedTensorType == op.getType())
933  return failure();
934 
935  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
936  foldedDynamicSizes);
937  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
938  return success();
939  }
940 };
941 
942 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
943  using OpRewritePattern<DimOp>::OpRewritePattern;
944
945  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
946  PatternRewriter &rewriter) const override {
947  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
948  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
949  if (!emptyTensorOp || !maybeConstantIndex)
950  return failure();
951  if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
952  return failure();
953  rewriter.replaceOp(dimOp,
954  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
955  return success();
956  }
957 };
958 
959 /// Canonicalize
960 ///
961 /// ```mlir
962 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
963 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
964 /// ```
965 ///
966 /// into
967 ///
968 /// ```mlir
969 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
970 /// ```
971 ///
972 /// This assumes the input program is correct in terms of its shape. So it is
973 /// safe to assume that `%d0` is in fact 4.
974 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
975  using OpRewritePattern<CastOp>::OpRewritePattern;
976
977  LogicalResult matchAndRewrite(CastOp castOp,
978  PatternRewriter &rewriter) const override {
979  if (!canFoldIntoProducerOp(castOp))
980  return failure();
981  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
982  if (!producer)
983  return failure();
984 
985  auto resultType =
986  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
987  ArrayRef<int64_t> resultShape = resultType.getShape();
988  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
989  SmallVector<OpFoldResult> newMixedSizes;
990  newMixedSizes.reserve(currMixedSizes.size());
991  assert(resultShape.size() == currMixedSizes.size() &&
992  "mismatch in result shape and sizes of empty op");
993  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
994  int64_t newDim = std::get<0>(it);
995  OpFoldResult currDim = std::get<1>(it);
996  // Case 1: The empty tensor dim is static. Check that the tensor cast
997  // result dim matches.
998  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
999  if (ShapedType::isDynamic(newDim) ||
1000  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1001  // Something is off, the cast result shape cannot be more dynamic
1002  // than the empty tensor result shape (enforced by
1003  // `canFoldIntoProducer`). Abort for now.
1004  return rewriter.notifyMatchFailure(
1005  producer, "mismatch in static value of shape of empty tensor "
1006  "result and cast result");
1007  }
1008  newMixedSizes.push_back(attr);
1009  continue;
1010  }
1011 
1012  // Case 2 : The tensor cast shape is static, but empty tensor result
1013  // shape is dynamic.
1014  if (!ShapedType::isDynamic(newDim)) {
1015  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1016  continue;
1017  }
1018 
1019  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1020  // shape is dynamic. Use the dynamic value from the empty tensor op.
1021  newMixedSizes.push_back(currDim);
1022  }
1023 
1024  // TODO: Do not drop tensor encoding.
1025  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1026  resultType.getElementType());
1027  return success();
1028  }
1029 };
1030 
1031 } // namespace
1032 
1033 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1034  MLIRContext *context) {
1035  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1036  ReplaceEmptyTensorStaticShapeDims>(context);
1037 }
1038 
1039 /// Try to remove a tensor operation if it would only reshape a constant.
1040 /// Removes the op and replaces the constant with a new constant of the result
1041 /// shape.
1042 static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
1043  TensorType result) {
1044  if (source && source.isSplat() && result.hasStaticShape())
1045  return source.resizeSplat(result);
1046 
1047  return {};
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // ExtractOp
1052 //===----------------------------------------------------------------------===//
1053 
1054 namespace {
1055 
1056 /// Canonicalizes the pattern of the form
1057 ///
1058 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1059 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1060 ///
1061 /// to
1062 ///
1063 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1064 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1065  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1066
1067  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1068  PatternRewriter &rewriter) const final {
1069  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1070  if (!tensorCast)
1071  return failure();
1072  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1073  return failure();
1074  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1075  extract, tensorCast.getSource(), extract.getIndices());
1076  return success();
1077  }
1078 };
1079 
1080 } // namespace
1081 
1082 void ExtractOp::getAsmResultNames(
1083  function_ref<void(Value, StringRef)> setNameFn) {
1084  setNameFn(getResult(), "extracted");
1085 }
1086 
1087 LogicalResult ExtractOp::verify() {
1088  // Verify the # indices match if we have a ranked type.
1089  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1090  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1091  return emitOpError("incorrect number of indices for extract_element");
1092  return success();
1093 }
1094 
1095 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1096  // If this is a splat elements attribute, simply return the value. All of
1097  // the elements of a splat attribute are the same.
1098  if (Attribute tensor = adaptor.getTensor())
1099  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1100  return splatTensor.getSplatValue<Attribute>();
1101 
1102  // Collect the constant indices into the tensor.
1103  SmallVector<uint64_t, 8> indices;
1104  for (Attribute indice : adaptor.getIndices()) {
1105  if (!indice || !llvm::isa<IntegerAttr>(indice))
1106  return {};
1107  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1108  }
1109 
1110  // Fold extract(from_elements(...)).
1111  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1112  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1113  auto rank = tensorType.getRank();
1114  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1115  "rank mismatch");
1116  int flatIndex = 0;
1117  int stride = 1;
1118  for (int i = rank - 1; i >= 0; --i) {
1119  flatIndex += indices[i] * stride;
1120  stride *= tensorType.getDimSize(i);
1121  }
1122  // Prevent out of bounds accesses. This can happen in invalid code that
1123  // will never execute.
1124  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1125  flatIndex < 0)
1126  return {};
1127  return fromElementsOp.getElements()[flatIndex];
1128  }
1129 
1130  // If this is an elements attribute, query the value at the given indices.
1131  if (Attribute tensor = adaptor.getTensor()) {
1132  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1133  if (elementsAttr && elementsAttr.isValidIndex(indices))
1134  return elementsAttr.getValues<Attribute>()[indices];
1135  }
1136 
1137  return {};
1138 }
1139 
1140 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1141  MLIRContext *context) {
1142  results.add<ExtractFromTensorCast>(context);
1143 }
1144 
1145 //===----------------------------------------------------------------------===//
1146 // FromElementsOp
1147 //===----------------------------------------------------------------------===//
1148 
1149 void FromElementsOp::getAsmResultNames(
1150  function_ref<void(Value, StringRef)> setNameFn) {
1151  setNameFn(getResult(), "from_elements");
1152 }
1153 
1154 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1155  ValueRange elements) {
1156  assert(!elements.empty() && "expected at least one element");
1157  Type resultType = RankedTensorType::get(
1158  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1159  build(builder, result, resultType, elements);
1160 }
1161 
1162 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1163  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1164  return DenseElementsAttr::get(getType(), adaptor.getElements());
1165  return {};
1166 }
1167 
1168 namespace {
1169 
1170 // Pushes the index_casts that occur before extractions to after the extract.
1171 // This minimizes type conversion in some cases and enables the extract
1172 // canonicalizer. This changes:
1173 //
1174 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1175 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1176 //
1177 // to the following:
1178 //
1179 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1180 // %cast = arith.index_cast %extract : i32 to index
1181 //
1182 // so that the cast is applied to the extracted scalar rather than the tensor.
1183 //
1184 // Consider expanding this to a template and handling all tensor cast
1185 // operations.
1186 struct ExtractElementFromIndexCast
1187  : public OpRewritePattern<tensor::ExtractOp> {
1188  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1189
1190  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1191  PatternRewriter &rewriter) const final {
1192  Location loc = extract.getLoc();
1193  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1194  if (!indexCast)
1195  return failure();
1196 
1197  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1198 
1199  auto newExtract = rewriter.create<tensor::ExtractOp>(
1200  loc, elementTy, indexCast.getIn(), extract.getIndices());
1201 
1202  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1203  newExtract);
1204 
1205  return success();
1206  }
1207 };
1208 
1209 } // namespace
1210 
1211 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1212  MLIRContext *context) {
1213  results.add<ExtractElementFromIndexCast>(context);
1214 }
1215 
1216 //===----------------------------------------------------------------------===//
1217 // GatherOp
1218 //===----------------------------------------------------------------------===//
1219 
1220 void GatherOp::getAsmResultNames(
1221  function_ref<void(Value, StringRef)> setNameFn) {
1222  setNameFn(getResult(), "gather");
1223 }
1224 
1225 /// Return the inferred result type for a gatherOp where:
1226 /// - sourceType is the type of the source tensor gathered from
1227 /// - indicesType is the type of the indices used to gather
1228 /// - gatherDims are the dims along which the gather occurs.
1229 /// Return a full-rank or rank-reduced variant of the type depending on
1230 /// the value of rankReduced.
1231 ///
1232 /// The leading dimensions of the index tensor give the result tensor its
1233 /// leading dimensions.
1234 /// The trailing dimensions of the result tensor are obtained from the source
1235 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1236 /// rankReduced is false), or skipping them (otherwise).
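/// For example (shapes chosen for illustration): with a source of
/// tensor<4x5x6xf32>, indices of tensor<7x1xindex> and gather_dims = [1], the
/// inferred type is tensor<7x4x1x6xf32>, or tensor<7x4x6xf32> in the
/// rank-reduced form.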
1237 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1238  RankedTensorType indicesType,
1239  ArrayRef<int64_t> gatherDims,
1240  bool rankReduced) {
1241  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1242  resultShape.reserve(resultShape.size() + sourceType.getRank());
1243  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1244  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1245  if (!rankReduced)
1246  resultShape.push_back(1);
1247  continue;
1248  }
1249  resultShape.push_back(sourceType.getDimSize(idx));
1250  }
1251  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1252 }
1253 
1254 static LogicalResult
1255 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
1256  StringRef gatherOrScatter, StringRef sourceOrDest) {
1257  if (dims.empty())
1258  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1259 
1260  int64_t numGatherDims = dims.size();
1261  if (numGatherDims > rank)
1262  return op->emitOpError(gatherOrScatter)
1263  << "_dims overflow " << sourceOrDest << " rank";
1264  for (int64_t val : dims) {
1265  if (val < 0)
1266  return op->emitOpError(gatherOrScatter)
1267  << "_dims value must be non-negative";
1268  if (val >= rank)
1269  return op->emitOpError(gatherOrScatter)
1270  << "_dims value must be smaller than " << sourceOrDest << " rank";
1271  }
1272  for (int64_t i = 1; i < numGatherDims; ++i) {
1273  if (dims[i - 1] >= dims[i])
1274  return op->emitOpError(gatherOrScatter)
1275  << "_dims values must be strictly increasing";
1276  }
1277  return success();
1278 }
1279 
1280 LogicalResult GatherOp::verify() {
1281  int64_t sourceRank = getSourceType().getRank();
1282  ArrayRef<int64_t> gatherDims = getGatherDims();
1283  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
1284  "gather", "source")))
1285  return failure();
1286 
1287  RankedTensorType expectedResultType = GatherOp::inferResultType(
1288  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1289  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1290  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1291  if (getResultType() != expectedResultType &&
1292  getResultType() != expectedRankReducedResultType) {
1293  return emitOpError("result type "
1294  "mismatch: "
1295  "expected ")
1296  << expectedResultType << " or its rank-reduced variant "
1297  << expectedRankReducedResultType << " (got: " << getResultType()
1298  << ")";
1299  }
1300 
1301  return success();
1302 }
1303 
1304 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1305  if (OpFoldResult reshapedSource = reshapeConstantSource(
1306  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1307  getResult().getType()))
1308  return reshapedSource;
1309  return {};
1310 }
1311 
1312 //===----------------------------------------------------------------------===//
1313 // InsertOp
1314 //===----------------------------------------------------------------------===//
1315 
1316 void InsertOp::getAsmResultNames(
1317  function_ref<void(Value, StringRef)> setNameFn) {
1318  setNameFn(getResult(), "inserted");
1319 }
1320 
1321 LogicalResult InsertOp::verify() {
1322  // Verify the # indices match if we have a ranked type.
1323  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1324  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1325  return emitOpError("incorrect number of indices");
1326  return success();
1327 }
1328 
1329 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1330  Attribute scalar = adaptor.getScalar();
1331  Attribute dest = adaptor.getDest();
1332  if (scalar && dest)
1333  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1334  if (scalar == splatDest.getSplatValue<Attribute>())
1335  return dest;
1336  return {};
1337 }
1338 
1339 //===----------------------------------------------------------------------===//
1340 // GenerateOp
1341 //===----------------------------------------------------------------------===//
1342 
1343 void GenerateOp::getAsmResultNames(
1344  function_ref<void(Value, StringRef)> setNameFn) {
1345  setNameFn(getResult(), "generated");
1346 }
1347 
1348 LogicalResult GenerateOp::reifyResultShapes(
1349  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1350  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1351  int idx = 0;
1352  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1353  if (getType().isDynamicDim(dim)) {
1354  reifiedReturnShapes[0][dim] = getOperand(idx++);
1355  } else {
1356  reifiedReturnShapes[0][dim] =
1357  builder.getIndexAttr(getType().getDimSize(dim));
1358  }
1359  }
1360  return success();
1361 }
1362 
1363 LogicalResult GenerateOp::verify() {
1364  // Ensure that the tensor type has as many dynamic dimensions as are
1365  // specified by the operands.
1366  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1367  if (getNumOperands() != resultType.getNumDynamicDims())
1368  return emitError("must have as many index operands as dynamic extents "
1369  "in the result type");
1370  return success();
1371 }
1372 
1373 LogicalResult GenerateOp::verifyRegions() {
1374  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1375  // Ensure that region arguments span the index space.
1376  if (!llvm::all_of(getBody().getArgumentTypes(),
1377  [](Type ty) { return ty.isIndex(); }))
1378  return emitError("all body arguments must be index");
1379  if (getBody().getNumArguments() != resultTy.getRank())
1380  return emitError("must have one body argument per input dimension");
1381 
1382  // Ensure that the region yields an element of the right type.
1383  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1384 
1385  if (yieldOp.getValue().getType() != resultTy.getElementType())
1386  return emitOpError(
1387  "body must be terminated with a `yield` operation of the tensor "
1388  "element type");
1389 
1390  return success();
1391 }
1392 
1393 void GenerateOp::build(
1394  OpBuilder &b, OperationState &result, Type resultTy,
1395  ValueRange dynamicExtents,
1396  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1397  build(b, result, resultTy, dynamicExtents);
1398 
1399  // Build and populate body.
1400  OpBuilder::InsertionGuard guard(b);
1401  Region *bodyRegion = result.regions.front().get();
1402  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1403  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1404  SmallVector<Location, 2> argumentLocs(rank, result.location);
1405  Block *bodyBlock =
1406  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1407  bodyBuilder(b, result.location, bodyBlock->getArguments());
1408 }
1409 
1410 namespace {
1411 
1412 /// Canonicalizes tensor.generate operations with constant extent operands
1413 /// into the equivalent operation with those extents reflected in the result
1414 /// type instead. A type cast is also inserted to make sure that the
1415 /// resulting IR is still well-typed.
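/// For example (illustrative IR, with %elem some f32 value):
/// ```mlir
/// %c16 = arith.constant 16 : index
/// %0 = tensor.generate %c16 {
/// ^bb0(%i : index):
///   tensor.yield %elem : f32
/// } : tensor<?xf32>
/// ```
/// is rewritten to generate a tensor<16xf32> directly, followed by a
/// tensor.cast back to tensor<?xf32>.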
1416 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1417  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1418
1419  LogicalResult matchAndRewrite(GenerateOp generateOp,
1420  PatternRewriter &rewriter) const final {
1421  SmallVector<Value> foldedDynamicSizes;
1422  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1423  generateOp.getType(), generateOp.getDynamicExtents(),
1424  foldedDynamicSizes);
1425 
1426  // Stop here if no dynamic size was promoted to static.
1427  if (foldedTensorType == generateOp.getType())
1428  return failure();
1429 
1430  auto loc = generateOp.getLoc();
1431  auto newOp =
1432  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1433  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1434  newOp.getBody().begin());
1435  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1436  generateOp.getType(), newOp);
1437  return success();
1438  }
1439 };
1440 
1441 /// Canonicalizes the pattern of the form
1442 ///
1443 /// %tensor = tensor.generate %x {
1444 /// ^bb0(%arg0: index):
1445 /// <computation>
1446 /// yield %1 : index
1447 /// } : tensor<?xindex>
1448 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1449 ///
1450 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1451 /// tensor.generate operation has no side-effects.
1452 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1453  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1454
1455  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1456  PatternRewriter &rewriter) const final {
1457  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1458  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1459  return failure();
1460 
1461  IRMapping mapping;
1462  Block *body = &tensorFromElements.getBody().front();
1463  mapping.map(body->getArguments(), extract.getIndices());
1464  for (auto &op : body->without_terminator())
1465  rewriter.clone(op, mapping);
1466 
1467  auto yield = cast<YieldOp>(body->getTerminator());
1468 
1469  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1470  return success();
1471  }
1472 };
1473 
1474 } // namespace
1475 
1476 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1477  MLIRContext *context) {
1478  // TODO: Move extract pattern to tensor::ExtractOp.
1479  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1480 }
1481 
1482 //===----------------------------------------------------------------------===//
1483 // RankOp
1484 //===----------------------------------------------------------------------===//
1485 
1486 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1487  setNameFn(getResult(), "rank");
1488 }
1489 
1490 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1491  // Constant fold rank when the rank of the operand is known.
1492  auto type = getOperand().getType();
1493  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1494  if (shapedType && shapedType.hasRank())
1495  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1496  return IntegerAttr();
1497 }
1498 
1499 //===----------------------------------------------------------------------===//
1500 // ReshapeOp
1501 //===----------------------------------------------------------------------===//
1502 
1503 void ReshapeOp::getAsmResultNames(
1504  function_ref<void(Value, StringRef)> setNameFn) {
1505  setNameFn(getResult(), "reshape");
1506 }
1507 
1508 static int64_t getNumElements(ShapedType type) {
1509  int64_t numElements = 1;
1510  for (auto dim : type.getShape())
1511  numElements *= dim;
1512  return numElements;
1513 }
1514 
1515 LogicalResult ReshapeOp::verify() {
1516  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1517  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1518 
1519  if (operandType.getElementType() != resultType.getElementType())
1520  return emitOpError("element types of source and destination tensor "
1521  "types should be the same");
1522 
1523  int64_t shapeSize =
1524  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1525  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1526  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1527 
1528  if (resultRankedType) {
1529  if (operandRankedType && resultRankedType.hasStaticShape() &&
1530  operandRankedType.hasStaticShape()) {
1531  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1532  return emitOpError("source and destination tensor should have the "
1533  "same number of elements");
1534  }
1535  if (ShapedType::isDynamic(shapeSize))
1536  return emitOpError("cannot use shape operand with dynamic length to "
1537  "reshape to statically-ranked tensor type");
1538  if (shapeSize != resultRankedType.getRank())
1539  return emitOpError(
1540  "length of shape operand differs from the result's tensor rank");
1541  }
1542  return success();
1543 }
1544 
1545 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1546  if (OpFoldResult reshapedSource = reshapeConstantSource(
1547  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1548  getResult().getType()))
1549  return reshapedSource;
1550  return {};
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // Reassociative reshape ops
1555 //===----------------------------------------------------------------------===//
1556 
1557 void CollapseShapeOp::getAsmResultNames(
1558  function_ref<void(Value, StringRef)> setNameFn) {
1559  setNameFn(getResult(), "collapsed");
1560 }
1561 
1562 void ExpandShapeOp::getAsmResultNames(
1563  function_ref<void(Value, StringRef)> setNameFn) {
1564  setNameFn(getResult(), "expanded");
1565 }
1566 
1567 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1568  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1569  "invalid resultDim");
1570  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1571  if (llvm::is_contained(it.value(), resultDim))
1572  return it.index();
1573  llvm_unreachable("could not find reassociation group");
1574 }
1575 
1576 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1577  return getSymbolLessAffineMaps(getReassociationExprs());
1578 }
1579 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1580  return convertReassociationIndicesToExprs(getContext(),
1581  getReassociationIndices());
1582 }
1583 
1584 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1585  return getSymbolLessAffineMaps(getReassociationExprs());
1586 }
1587 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1588  return convertReassociationIndicesToExprs(getContext(),
1589  getReassociationIndices());
1590 }
1591 
1592 RankedTensorType CollapseShapeOp::inferCollapsedType(
1593  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1594  return inferCollapsedType(
1595  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1596  type.getContext(), reassociation)));
1597 }
1598 
1599 /// Compute the RankedTensorType obtained by applying `reassociation` to
1600 /// `type`.
1601 RankedTensorType
1602 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1603  ArrayRef<AffineMap> reassociation) {
1604  auto shape = type.getShape();
1605  SmallVector<int64_t, 4> newShape;
1606  newShape.reserve(reassociation.size());
1607 
1608  // Use the fact that reassociation is valid to simplify the logic: only use
1609  // each map's rank.
1610  assert(isReassociationValid(reassociation) && "invalid reassociation");
1611  unsigned currentDim = 0;
1612  for (AffineMap m : reassociation) {
1613  unsigned dim = m.getNumResults();
1614  auto band = shape.slice(currentDim, dim);
1615  int64_t size = 1;
1616  if (llvm::is_contained(band, ShapedType::kDynamic))
1617  size = ShapedType::kDynamic;
1618  else
1619  for (unsigned d = 0; d < dim; ++d)
1620  size *= shape[currentDim + d];
1621  newShape.push_back(size);
1622  currentDim += dim;
1623  }
1624 
1625  return RankedTensorType::get(newShape, type.getElementType());
1626 }
1627 
1628 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1629  ArrayRef<ReassociationIndices> reassociation,
1630  ArrayRef<NamedAttribute> attrs) {
1631  auto resultType = inferCollapsedType(
1632  llvm::cast<RankedTensorType>(src.getType()),
1633  getSymbolLessAffineMaps(
1634  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1635  build(b, result, resultType, src, attrs);
1636  result.addAttribute(getReassociationAttrStrName(),
1637  getReassociationIndicesAttribute(b, reassociation));
1638 }
1639 
1640 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1641  TensorReshapeOp, ExpandShapeOp>::value>
1642 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1643  RankedTensorType expandedType,
1644  RankedTensorType collapsedType) {
1645  if (failed(
1646  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1647  return failure();
1648 
1649  auto maps = op.getReassociationMaps();
1650  RankedTensorType expectedType =
1651  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1652  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1653  return op.emitOpError("expected collapsed type to be ")
1654  << expectedType << ", but got " << collapsedType;
1655  return success();
1656 }
1657 
1658 LogicalResult ExpandShapeOp::verify() {
1659  auto srcType = getSrcType();
1660  auto resultType = getResultType();
1661  if (srcType.getRank() >= resultType.getRank())
1662  return emitOpError("expected rank expansion, but found source rank ")
1663  << srcType.getRank() << " >= result rank " << resultType.getRank();
1664 
1665  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
1666 }
1667 
1668 LogicalResult CollapseShapeOp::verify() {
1669  auto srcType = getSrcType();
1670  auto resultType = getResultType();
1671  if (srcType.getRank() <= resultType.getRank())
1672  return emitOpError("expected rank reduction, but found source rank ")
1673  << srcType.getRank() << " <= result rank " << resultType.getRank();
1674 
1675  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1676 }
1677 
1678 namespace {
1679 /// Reshape of a splat constant can be replaced with a constant of the result
1680 /// type.
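/// Illustrative example (a sketch, not taken from the original comments):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<2x4xf32>
/// %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
/// ```
///
/// becomes `arith.constant dense<1.0> : tensor<8xf32>`.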
1681 template <typename TensorReshapeOp>
1682 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1683  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1684  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1685  PatternRewriter &rewriter) const override {
1686  DenseElementsAttr attr;
1687  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1688  return failure();
1689  if (!attr || !attr.isSplat())
1690  return failure();
1691  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1692  reshapeOp.getResultType(), attr.getRawData());
1693  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1694  return success();
1695  }
1696 };
1697 
1698 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
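//
// Illustrative example (a sketch, not taken from the original comments):
//
//   %s = tensor.splat %x : tensor<2x4xf32>
//   %0 = tensor.collapse_shape %s [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
//
// becomes `tensor.splat %x : tensor<8xf32>`.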
1699 template <typename TensorReshapeOp>
1700 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1701 public:
1702  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1703 
1704  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1705  PatternRewriter &rewriter) const override {
1706  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1707  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1708  return failure();
1709 
1710  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1711  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1712  return success();
1713  }
1714 };
1715 
1716 /// Reshape of a FromElements can be replaced with a FromElements of the
1717 /// result type
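///
/// Illustrative example (a sketch, not taken from the original comments):
///
/// ```mlir
/// %0 = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<2x2xf32> into tensor<4xf32>
/// ```
///
/// becomes `tensor.from_elements %a, %b, %c, %d : tensor<4xf32>`.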
1718 template <typename TensorReshapeOp>
1719 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1720  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1721  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1722  PatternRewriter &rewriter) const override {
1723  auto fromElements =
1724  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1725  if (!fromElements)
1726  return failure();
1727 
1728  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1729 
1730  if (!shapedTy.hasStaticShape())
1731  return failure();
1732 
1733  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1734  fromElements.getElements());
1735  return success();
1736  }
1737 };
1738 
1739 // Fold CastOp into CollapseShapeOp when adding static information.
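//
// Illustrative example (a sketch, not taken from the original comments):
//
//   %0 = tensor.cast %src : tensor<4x4xf32> to tensor<?x?xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
//
// becomes a collapse_shape of %src producing tensor<16xf32>, followed by a
// tensor.cast back to tensor<?xf32>.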
1740 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1741  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1742 
1743  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1744  PatternRewriter &rewriter) const override {
1745  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1746  if (!tensor::canFoldIntoConsumerOp(castOp))
1747  return failure();
1748 
1749  RankedTensorType srcType =
1750  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1751  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1752  srcType, collapseShapeOp.getReassociationMaps());
1753 
1754  if (newResultType == collapseShapeOp.getResultType()) {
1755  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1756  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1757  });
1758  } else {
1759  auto newOp = rewriter.create<CollapseShapeOp>(
1760  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1761  collapseShapeOp.getReassociation());
1762  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1763  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1764  }
1765  return success();
1766  }
1767 };
1768 
1769 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1770  using OpRewritePattern<DimOp>::OpRewritePattern;
1771 
1772  LogicalResult matchAndRewrite(DimOp dimOp,
1773  PatternRewriter &rewriter) const override {
1774  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1775  if (!expandShapeOp)
1776  return failure();
1777 
1778  // Only constant dimension values are supported.
1779  std::optional<int64_t> dim = dimOp.getConstantIndex();
1780  if (!dim.has_value())
1781  return failure();
1782 
1783  // Skip static dims. These are folded to constant ops.
1784  RankedTensorType resultType = expandShapeOp.getResultType();
1785  if (!resultType.isDynamicDim(*dim))
1786  return failure();
1787 
1788  // Find reassociation group that contains this result dimension.
1789  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1790 
1791  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1792  // ExpandShapeOp would be ambiguous.)
1793  int64_t product = 1;
1794  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1795  for (int64_t d : grp) {
1796  if (d != dim) {
1797  assert(!resultType.isDynamicDim(d) && "expected static dim");
1798  product *= resultType.getDimSize(d);
1799  }
1800  }
1801 
1802  // result dim size = src dim size / (product(other dims in reassoc group))
1803  Value srcDimSz =
1804  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1805  AffineExpr expr;
1806  bindSymbols(dimOp.getContext(), expr);
1807  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1808  dimOp, expr.floorDiv(product), srcDimSz);
1809  return success();
1810  }
1811 };
1812 
1813 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1814  using OpRewritePattern<DimOp>::OpRewritePattern;
1815 
1816  LogicalResult matchAndRewrite(DimOp dimOp,
1817  PatternRewriter &rewriter) const override {
1818  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1819  if (!collapseShapeOp)
1820  return failure();
1821 
1822  // Only constant dimension values are supported.
1823  std::optional<int64_t> dim = dimOp.getConstantIndex();
1824  if (!dim.has_value())
1825  return failure();
1826 
1827  // Skip static dims. These are folded to constant ops.
1828  RankedTensorType resultType = collapseShapeOp.getResultType();
1829  if (!resultType.isDynamicDim(*dim))
1830  return failure();
1831 
1832  // Get reassociation group of the result dimension.
1833  ReassociationIndices group =
1834  collapseShapeOp.getReassociationIndices()[*dim];
1835 
1836  // result dim size = product(dims in reassoc group)
1837  SmallVector<Value> srcDimSizes;
1838  SmallVector<AffineExpr> syms;
1839  AffineExpr product;
1840  for (const auto &it : llvm::enumerate(group)) {
1841  srcDimSizes.push_back(rewriter.create<DimOp>(
1842  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1843  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1844  product = product ? product * syms.back() : syms.back();
1845  }
1846  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
1847  srcDimSizes);
1848  return success();
1849  }
1850 };
1851 } // namespace
1852 
1853 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1854  MLIRContext *context) {
1855  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
1856  ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
1857  FoldReshapeWithConstant<ExpandShapeOp>,
1858  FoldReshapeWithSplat<ExpandShapeOp>,
1859  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1860  FoldDimOfCollapseShape>(context);
1861 }
1862 
1863 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1864  MLIRContext *context) {
1865  results
1866  .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
1867  ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
1868  FoldReshapeWithConstant<CollapseShapeOp>,
1869  FoldReshapeWithSplat<CollapseShapeOp>,
1870  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
1871  context);
1872 }
1873 
1874 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
1875  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
1876  adaptor.getOperands());
1877 }
1878 
1879 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
1880  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
1881  adaptor.getOperands());
1882 }
1883 
1884 //===----------------------------------------------------------------------===//
1885 // ExtractSliceOp
1886 //===----------------------------------------------------------------------===//
1887 
1888 void ExtractSliceOp::getAsmResultNames(
1889  function_ref<void(Value, StringRef)> setNameFn) {
1890  setNameFn(getResult(), "extracted_slice");
1891 }
1892 
1893 /// An extract_slice result type can be inferred, when it is not
1894 /// rank-reduced, from the source type and the static representation of
1895 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
1896 RankedTensorType ExtractSliceOp::inferResultType(
1897  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
1898  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
1899  // An extract_slice op may specify only a leading subset of offset/sizes/
1900  // strides, in which case we complete with offset=0, sizes from the source
1901  // tensor type, and strides=1.
1902  assert(static_cast<int64_t>(staticSizes.size()) ==
1903  sourceTensorType.getRank() &&
1904  "unexpected staticSizes not equal to rank of source");
1905  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
1906 }
1907 
1908 RankedTensorType ExtractSliceOp::inferResultType(
1909  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
1910  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
1911  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1912  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1913  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1914  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1915  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1916  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
1917  staticSizes, staticStrides);
1918 }
1919 
1920 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
1921 /// number of sizes), drop as many size 1 as needed to produce an inferred
1922 /// type with the desired rank.
1923 ///
1924 /// Note that there may be multiple ways to compute this rank-reduced type:
1925 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
1926 ///
1927 /// To disambiguate, this function always drops the size-1 dimensions that occur first.
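///
/// Illustrative example (a sketch, not taken from the original comments): with
/// inferred sizes 1x6x1 and desiredResultRank = 2, the leading unit dimension
/// is dropped first, so the rank-reduced type is tensor<6x1xf32> (for an f32
/// element type) rather than tensor<1x6xf32>.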
1928 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1929  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1930  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1931  ArrayRef<int64_t> strides) {
1932  // Type inferred in the absence of rank-reducing behavior.
1933  auto inferredType = llvm::cast<RankedTensorType>(
1934  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
1935  int rankDiff = inferredType.getRank() - desiredResultRank;
1936  if (rankDiff > 0) {
1937  auto shape = inferredType.getShape();
1938  llvm::SmallBitVector dimsToProject =
1939  getPositionsOfShapeOne(rankDiff, shape);
1940  SmallVector<int64_t> projectedShape;
1941  // Best effort rank-reducing: drop 1s in order.
1942  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1943  if (!dimsToProject.test(pos))
1944  projectedShape.push_back(shape[pos]);
1945  inferredType =
1946  RankedTensorType::get(projectedShape, inferredType.getElementType());
1947  }
1948  return inferredType;
1949 }
1950 
1951 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1952  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1953  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1954  ArrayRef<OpFoldResult> strides) {
1955  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1956  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1957  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1958  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1959  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1960  return ExtractSliceOp::inferCanonicalRankReducedResultType(
1961  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1962  staticStrides);
1963 }
1964 
1965 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1966 /// result type. If the type passed is nullptr, it is inferred.
1967 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1968  RankedTensorType resultType, Value source,
1969  ArrayRef<OpFoldResult> offsets,
1970  ArrayRef<OpFoldResult> sizes,
1971  ArrayRef<OpFoldResult> strides,
1972  ArrayRef<NamedAttribute> attrs) {
1973  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1974  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1975  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1976  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1977  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1978  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
1979  // Structuring implementation this way avoids duplication between builders.
1980  if (!resultType) {
1981  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
1982  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
1983  }
1984  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1985  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1986  b.getDenseI64ArrayAttr(staticSizes),
1987  b.getDenseI64ArrayAttr(staticStrides));
1988  result.addAttributes(attrs);
1989 }
1990 
1991 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1992 /// result type.
1993 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1994  ArrayRef<OpFoldResult> offsets,
1995  ArrayRef<OpFoldResult> sizes,
1996  ArrayRef<OpFoldResult> strides,
1997  ArrayRef<NamedAttribute> attrs) {
1998  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1999 }
2000 
2001 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2002 /// a Range vector.
2003 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2004  ArrayRef<Range> ranges,
2005  ArrayRef<NamedAttribute> attrs) {
2006  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2007  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2008 }
2009 
2010 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2011 /// the type passed is nullptr, it is inferred.
2012 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2013  RankedTensorType resultType, Value source,
2014  ValueRange offsets, ValueRange sizes,
2015  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2016  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2017  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2018  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2019  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2020  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2021  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2022  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2023 }
2024 
2025 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2026 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2027  ValueRange offsets, ValueRange sizes,
2028  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2029  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2030 }
2031 
2032 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2033  Operation *op,
2034  RankedTensorType expectedType) {
2035  switch (result) {
2036  case SliceVerificationResult::Success:
2037  return success();
2038  case SliceVerificationResult::RankTooLarge:
2039  return op->emitError("expected rank to be smaller or equal to ")
2040  << "the other rank. ";
2041  case SliceVerificationResult::SizeMismatch:
2042  return op->emitError("expected type to be ")
2043  << expectedType << " or a rank-reduced version. (size mismatch) ";
2044  case SliceVerificationResult::ElemTypeMismatch:
2045  return op->emitError("expected element type to be ")
2046  << expectedType.getElementType();
2047  default:
2048  llvm_unreachable("unexpected extract_slice op verification result");
2049  }
2050 }
2051 
2052 /// Verifier for ExtractSliceOp.
2053 LogicalResult ExtractSliceOp::verify() {
2054  // Verify result type against inferred type.
2055  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2056  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2057  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2058  return produceSliceErrorMsg(result, *this, expectedType);
2059 }
2060 
2061 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2062  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2063 }
2064 
2065 FailureOr<Value>
2066 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2067  ArrayRef<int64_t> desiredShape) {
2068  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2069  assert(sourceTensorType && "not a ranked tensor type");
2070  auto sourceShape = sourceTensorType.getShape();
2071  if (sourceShape.equals(desiredShape))
2072  return value;
2073  auto maybeRankReductionMask =
2074  mlir::computeRankReductionMask(sourceShape, desiredShape);
2075  if (!maybeRankReductionMask)
2076  return failure();
2077  return createCanonicalRankReducingExtractSliceOp(
2078  b, loc, value,
2079  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2080 }
2081 
2082 LogicalResult ExtractSliceOp::reifyResultShapes(
2083  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2084  reifiedReturnShapes.resize(1);
2085  reifiedReturnShapes[0].reserve(getType().getRank());
2086  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2087  llvm::SmallBitVector droppedDims = getDroppedDims();
2088  for (const auto &size : enumerate(mixedSizes)) {
2089  if (droppedDims.test(size.index()))
2090  continue;
2091  reifiedReturnShapes[0].push_back(size.value());
2092  }
2093  return success();
2094 }
2095 
2096 namespace {
2097 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2098 /// This essentially pushes the tensor.cast past its consuming slice when
2099 /// `canFoldIntoConsumerOp` is true.
2100 ///
2101 /// Example:
2102 /// ```
2103 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2104 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2105 /// tensor<3x4xf32>
2106 /// ```
2107 /// is rewritten into:
2108 /// ```
2109 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2110 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2111 /// ```
2112 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2113 public:
2114  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2115 
2116  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2117  PatternRewriter &rewriter) const override {
2118  // Any constant operand, just return to let the constant folder kick in.
2119  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2120  return matchPattern(operand, matchConstantIndex());
2121  }))
2122  return failure();
2123 
2124  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2125  if (!castOp)
2126  return failure();
2127 
2128  if (!canFoldIntoConsumerOp(castOp))
2129  return failure();
2130 
2131  // Create folded extract.
2132  Location loc = sliceOp.getLoc();
2133  Value newResult = rewriter.create<ExtractSliceOp>(
2134  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2135  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2136  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2137  if (newResult.getType() != sliceOp.getType())
2138  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2139  rewriter.replaceOp(sliceOp, newResult);
2140  return success();
2141  }
2142 };
2143 
2144 /// Slice elements from `values` into `outValues`. `counts` represents the
2145 /// number of elements to stride over in the original values for each dimension.
2146 /// The output values can be used to construct a DenseElementsAttr.
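/// For example (an illustrative walk-through, not from the original comments):
/// slicing a 2x3 source has counts = [3, 1]; with offsets = [0, 1],
/// sizes = [2, 2], and strides = [1, 1], the recursion visits the linear
/// indices 1, 2, 4, 5, i.e. elements (0,1), (0,2), (1,1), (1,2).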
2147 template <typename IterTy, typename ElemTy>
2148 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2149  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2150  ArrayRef<int64_t> strides,
2151  llvm::SmallVectorImpl<ElemTy> *outValues) {
2152  assert(offsets.size() == sizes.size());
2153  assert(offsets.size() == strides.size());
2154  if (offsets.empty())
2155  return;
2156 
2157  int64_t offset = offsets.front();
2158  int64_t size = sizes.front();
2159  int64_t stride = strides.front();
2160  if (offsets.size() == 1) {
2161  for (int64_t i = 0; i < size; ++i, offset += stride)
2162  outValues->push_back(*(values + offset));
2163 
2164  return;
2165  }
2166 
2167  for (int64_t i = 0; i < size; ++i, offset += stride) {
2168  auto begin = values + offset * counts.front();
2169  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2170  offsets.drop_front(), sizes.drop_front(),
2171  strides.drop_front(), outValues);
2172  }
2173 }
2174 
2175 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2176 /// folded operation might introduce more constant data; users can control
2177 /// whether the fold applies through the control function.
2178 class ConstantOpExtractSliceFolder final
2179  : public OpRewritePattern<ExtractSliceOp> {
2180 public:
2181  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2182 
2183  ConstantOpExtractSliceFolder(MLIRContext *context,
2184  ControlConstantExtractSliceFusionFn controlFn)
2185  : OpRewritePattern<ExtractSliceOp>(context),
2186  controlFn(std::move(controlFn)) {}
2187 
2188  LogicalResult matchAndRewrite(ExtractSliceOp op,
2189  PatternRewriter &rewriter) const override {
2190  DenseElementsAttr attr;
2191  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2192  return failure();
2193 
2194  // A constant splat is handled by fold().
2195  if (attr.isSplat())
2196  return failure();
2197 
2198  // Dynamic result shape is not supported.
2199  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2200  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2201  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2202  return failure();
2203 
2204  // Customized control over the folding.
2205  if (!controlFn(op))
2206  return failure();
2207 
2208  int64_t count = sourceType.getNumElements();
2209  if (count == 0)
2210  return failure();
2211 
2212  // Check if there are any dynamic parts, which are not supported.
2213  auto offsets = op.getStaticOffsets();
2214  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2215  return failure();
2216  auto sizes = op.getStaticSizes();
2217  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2218  return failure();
2219  auto strides = op.getStaticStrides();
2220  if (llvm::is_contained(strides, ShapedType::kDynamic))
2221  return failure();
2222 
2223  // Compute the stride for each dimension.
2224  SmallVector<int64_t> counts;
2225  ArrayRef<int64_t> shape = sourceType.getShape();
2226  counts.reserve(shape.size());
2227  for (int64_t v : shape) {
2228  count = count / v;
2229  counts.push_back(count);
2230  }
2231 
2232  // New attribute constructed by the sliced values.
2233  DenseElementsAttr newAttr;
2234 
2235  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2236  SmallVector<APInt> outValues;
2237  outValues.reserve(sourceType.getNumElements());
2238  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2239  elems.begin(), counts, offsets, sizes, strides, &outValues);
2240  newAttr = DenseElementsAttr::get(resultType, outValues);
2241  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2242  SmallVector<APFloat> outValues;
2243  outValues.reserve(sourceType.getNumElements());
2244  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2245  elems.begin(), counts, offsets, sizes, strides, &outValues);
2246  newAttr = DenseElementsAttr::get(resultType, outValues);
2247  }
2248 
2249  if (newAttr) {
2250  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2251  return success();
2252  }
2253 
2254  return failure();
2255  }
2256 
2257 private:
2258  /// This additionally controls whether the fold happens or not. Users can
2259  /// impose their heuristics in the function.
2260  ControlConstantExtractSliceFusionFn controlFn;
2261 };
2262 
2263 } // namespace
2264 
2265 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2266  RewritePatternSet &patterns,
2267  const ControlConstantExtractSliceFusionFn &controlFn) {
2268  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2269 }
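
// A minimal usage sketch (illustrative only; the size-threshold heuristic and
// the surrounding variables are hypothetical, while
// populateFoldConstantExtractSlicePatterns and
// ControlConstantExtractSliceFusionFn are the entry points defined above):
//
//   RewritePatternSet patterns(ctx);
//   tensor::populateFoldConstantExtractSlicePatterns(
//       patterns, [](ExtractSliceOp op) {
//         // Only fold when the resulting constant stays small.
//         return op.getType().getNumElements() <= 16;
//       });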
2270 
2271 /// Return the canonical type of the result of an extract_slice op.
2272 struct SliceReturnTypeCanonicalizer {
2273  RankedTensorType operator()(ExtractSliceOp op,
2274  ArrayRef<OpFoldResult> mixedOffsets,
2275  ArrayRef<OpFoldResult> mixedSizes,
2276  ArrayRef<OpFoldResult> mixedStrides) {
2277  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2278  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2279  mixedStrides);
2280  }
2281 };
2282 
2283 /// A canonicalizer wrapper to replace ExtractSliceOps.
2284 struct SliceCanonicalizer {
2285  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2286  ExtractSliceOp newOp) {
2287  Value replacement = newOp.getResult();
2288  if (replacement.getType() != op.getType())
2289  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2290  replacement);
2291  rewriter.replaceOp(op, replacement);
2292  }
2293 };
2294 
2295 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2296  MLIRContext *context) {
2297  results.add<
2298  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2299  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2300  ExtractSliceOpCastFolder>(context);
2301 }
2302 
2303 /// Return success if `op` describes an identity slice: all offsets are 0, all strides are 1, and the sizes match the shape of `shapedType`.
2304 static LogicalResult
2305 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2306  ShapedType shapedType) {
2307  OpBuilder b(op.getContext());
2308  for (OpFoldResult ofr : op.getMixedOffsets())
2309  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2310  return failure();
2311  // Rank-reducing noops only need to inspect the leading dimensions:
2312  // llvm::zip is appropriate.
2313  auto shape = shapedType.getShape();
2314  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2315  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2316  return failure();
2317  for (OpFoldResult ofr : op.getMixedStrides())
2318  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2319  return failure();
2320  return success();
2321 }
2322 
2323 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2324 /// slice, we can return the InsertSliceOp's source directly.
2325 // TODO: This only checks the immediate producer; extend to go up the
2326 // insert/extract chain if the slices are disjoint.
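//
// Illustrative example (a sketch, not taken from the original comments):
//
//   %0 = tensor.insert_slice %src into %dest[0, 1] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x8xf32>
//   %1 = tensor.extract_slice %0[0, 1] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
//
// Here %1 folds to %src because both ops describe the same slice and the
// source/result types match.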
2327 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2328  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2329 
2330  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2331  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2332  insertOp.isSameAs(extractOp, isSame))
2333  return insertOp.getSource();
2334 
2335  return {};
2336 }
2337 
2338 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2339  if (OpFoldResult reshapedSource = reshapeConstantSource(
2340  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2341  getResult().getType()))
2342  return reshapedSource;
2343  if (getSourceType() == getType() &&
2344  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2345  return this->getSource();
2346  if (Value slice = foldExtractAfterInsertSlice(*this))
2347  return slice;
2348 
2349  return OpFoldResult();
2350 }
2351 
2352 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2353  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2354  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2355  unsigned rank = rankedTensorType.getRank();
2356  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2357  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2358  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2359  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2360  offsets, sizes, strides);
2361 }
2362 
2363 //===----------------------------------------------------------------------===//
2364 // InsertSliceOp
2365 //===----------------------------------------------------------------------===//
2366 
2367 void InsertSliceOp::getAsmResultNames(
2368  function_ref<void(Value, StringRef)> setNameFn) {
2369  setNameFn(getResult(), "inserted_slice");
2370 }
2371 
2372 // Build an InsertSliceOp with mixed static and dynamic entries.
2373 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2374  Value dest, ArrayRef<OpFoldResult> offsets,
2375  ArrayRef<OpFoldResult> sizes,
2376  ArrayRef<OpFoldResult> strides,
2377  ArrayRef<NamedAttribute> attrs) {
2378  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2379  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2380  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2381  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2382  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2383  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2384  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2385  b.getDenseI64ArrayAttr(staticSizes),
2386  b.getDenseI64ArrayAttr(staticStrides));
2387  result.addAttributes(attrs);
2388 }
2389 
2390 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2391 /// Range vector.
2392 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2393  Value dest, ArrayRef<Range> ranges,
2394  ArrayRef<NamedAttribute> attrs) {
2395  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2396  build(b, result, source, dest, offsets, sizes, strides, attrs);
2397 }
2398 
2399 // Build an InsertSliceOp with dynamic entries.
2400 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2401  Value dest, ValueRange offsets, ValueRange sizes,
2402  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2403  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2404  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2405  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2406  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2407  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2408  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2409  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2410 }
2411 
2412 /// Rank-reducing type verification for both InsertSliceOp and
2413 /// ParallelInsertSliceOp.
2414 static SliceVerificationResult verifyInsertSliceOp(
2415  RankedTensorType srcType, RankedTensorType dstType,
2416  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2417  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2418  // insert_slice is the inverse of extract_slice, use the same type
2419  // inference.
2420  RankedTensorType expected = ExtractSliceOp::inferResultType(
2421  dstType, staticOffsets, staticSizes, staticStrides);
2422  if (expectedType)
2423  *expectedType = expected;
2424  return isRankReducedType(expected, srcType);
2425 }
2426 
2427 /// Verifier for InsertSliceOp.
2428 LogicalResult InsertSliceOp::verify() {
2429  RankedTensorType expectedType;
2430  SliceVerificationResult result =
2431  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2432  getStaticSizes(), getStaticStrides(), &expectedType);
2433  return produceSliceErrorMsg(result, *this, expectedType);
2434 }
2435 
2436 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2437 /// can mutate the second InsertSliceOp's destination to the first one's.
2438 ///
2439 /// Example:
2440 ///
2441 /// ```mlir
2442 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2443 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2444 /// ```
2445 ///
2446 /// folds into:
2447 ///
2448 /// ```mlir
2449 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2450 /// ```
2451 ///
2452 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2453 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2454  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2455 
2456  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2457  if (!prevInsertOp ||
2458  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2459  !prevInsertOp.isSameAs(insertOp, isSame))
2460  return failure();
2461 
2462  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2463  return success();
2464 }
2465 
2466 /// Folds round-trip extract/insert slice op pairs.
2467 /// Example:
2468 /// ```mlir
2469 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2470 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2471 /// ```
2472 /// can be folded into %val.
2473 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2474  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2475 
2476  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2477  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2478  !extractOp.isSameAs(insertOp, isSame))
2479  return nullptr;
2480 
2481  return extractOp.getSource();
2482 }
2483 
2484 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2485  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2486  getSourceType() == getType() &&
2487  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2488  return this->getSource();
2489  if (succeeded(foldInsertAfterInsertSlice(*this)))
2490  return getResult();
2491  if (auto result = foldInsertAfterExtractSlice(*this))
2492  return result;
2493  return OpFoldResult();
2494 }
2495 
2497  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2498  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2499  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2500  return success();
2501 }
2502 
2503 namespace {
2504 /// Pattern to rewrite an insert_slice op with constant arguments.
2505 ///
2506 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2507 template <typename InsertOpTy>
2508 class InsertSliceOpConstantArgumentFolder final
2509  : public OpRewritePattern<InsertOpTy> {
2510 public:
2511  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2512 
2513  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2514  PatternRewriter &rewriter) const override {
2515  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2516  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2517  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2518 
2519  // If no constant operands were folded, just return.
2520  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2521  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2522  failed(foldDynamicStrideList(mixedStrides)))
2523  return failure();
2524 
2525  // Create the new op in canonical form.
2526  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2527  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2528  mixedOffsets, mixedSizes, mixedStrides);
2529  Value toInsert = insertSliceOp.getSource();
2530  if (sourceType != insertSliceOp.getSourceType()) {
2531  OpBuilder::InsertionGuard g(rewriter);
2532  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2533  // is that the insertion point is just before the ParallelCombiningOp in
2534  // the parallel case.
2535  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2536  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2537  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2538  sourceType, toInsert);
2539  }
2540  rewriter.replaceOpWithNewOp<InsertOpTy>(
2541  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2542  mixedSizes, mixedStrides);
2543  return success();
2544  }
2545 };
2546 
2547 /// Fold tensor_casts with insert_slice operations. If the source or
2548 /// destination tensor is a tensor_cast that removes static type information,
2549 /// the cast is folded into the insert_slice operation. E.g.:
2550 ///
2551 /// ```mlir
2552 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2553 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2554 /// ```
2555 ///
2556 /// folds into:
2557 ///
2558 /// ```mlir
2559 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2560 /// ```
2561 ///
2562 /// Note: When folding a cast on the destination tensor, the result of the
2563 /// insert_slice operation is casted to ensure that the type of the result did
2564 /// not change.
2565 ///
2566 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2567 template <typename InsertOpTy>
2568 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2569  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2570 
2571  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2572  PatternRewriter &rewriter) const override {
2573  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2574  return matchPattern(operand, matchConstantIndex());
2575  }))
2576  return failure();
2577 
2578  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2579  auto castOp = v.getDefiningOp<tensor::CastOp>();
2580  if (!castOp || !canFoldIntoConsumerOp(castOp))
2581  return std::nullopt;
2582  return castOp.getSource();
2583  };
2584  std::optional<Value> sourceCastSource =
2585  getSourceOfCastOp(insertSliceOp.getSource());
2586  std::optional<Value> destCastSource =
2587  getSourceOfCastOp(insertSliceOp.getDest());
2588  if (!sourceCastSource && !destCastSource)
2589  return failure();
2590 
2591  auto src =
2592  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2593  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2594  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2595  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2596  if (!srcType || !dstType)
2597  return failure();
2598  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2599  insertSliceOp.getStaticSizes(),
2600  insertSliceOp.getStaticStrides()) !=
2601  SliceVerificationResult::Success)
2602  return failure();
2603 
2604  Operation *replacement = rewriter.create<InsertOpTy>(
2605  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2606  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2607 
2608  // In the parallel case there is no result and so nothing to cast.
2609  bool isParallelInsert =
2610  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2611  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2612  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2613  insertSliceOp.getDestType(),
2614  replacement->getResult(0));
2615  }
2616  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2617  return success();
2618  }
2619 };
2620 
2621 /// If additional static type information can be deduced from an insert_slice's
2622 /// size operands, insert an explicit cast of the op's source operand. This
2623 /// enables other canonicalization patterns that are matching for tensor_cast
2624 /// ops such as `ForOpTensorCastFolder` in SCF.
2625 ///
2626 /// Example:
2627 ///
2628 /// ```mlir
2629 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2630 /// : tensor<?x?xf32> into ...
2631 /// ```
2632 ///
2633 /// folds into:
2634 ///
2635 /// ```mlir
2636 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2637 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2638 /// : tensor<64x64xf32> into ...
2639 /// ```
2640 ///
2641 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2642 template <typename InsertOpTy>
2643 struct InsertSliceOpSourceCastInserter final
2644  : public OpRewritePattern<InsertOpTy> {
2645  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2646 
2647  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2648  PatternRewriter &rewriter) const override {
2649  RankedTensorType srcType = insertSliceOp.getSourceType();
2650  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2651  return failure();
2652  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
2653  srcType.getShape().end());
2654  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2655  if (std::optional<int64_t> constInt =
2656  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2657  // Bail on invalid IR.
2658  if (*constInt < 0)
2659  return failure();
2660  newSrcShape[i] = *constInt;
2661  }
2662  }
2663  if (!hasValidSizesOffsets(newSrcShape))
2664  return failure();
2665 
2666  RankedTensorType newSrcType = RankedTensorType::get(
2667  newSrcShape, srcType.getElementType(), srcType.getEncoding());
2668  if (srcType == newSrcType ||
2669  !preservesStaticInformation(srcType, newSrcType) ||
2670  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2671  return failure();
2672 
2673  // newSrcType is:
2674  // 1) Different from srcType.
2675  // 2) "More static" than srcType.
2676  // 3) Cast-compatible with srcType.
2677  // Insert the cast.
2678  OpBuilder::InsertionGuard g(rewriter);
2679  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2680  // that the insertion point is just before the ParallelCombiningOp in the
2681  // parallel case.
2682  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2683  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2684  Value cast = rewriter.create<tensor::CastOp>(
2685  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2686  rewriter.replaceOpWithNewOp<InsertOpTy>(
2687  insertSliceOp, cast, insertSliceOp.getDest(),
2688  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2689  insertSliceOp.getMixedStrides());
2690  return success();
2691  }
2692 };
2693 } // namespace
2694 
2695 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2696  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2697 }
2698 
2699 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2700  MLIRContext *context) {
2701  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2702  InsertSliceOpCastFolder<InsertSliceOp>,
2703  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2704 }
2705 
2706 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2707  Location loc,
2708  Value tensor,
2709  Value dest) {
2710  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
2711  unsigned rank = rankedTensorType.getRank();
2712  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2713  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2714  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2715  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2716  sizes, strides);
2717 }
2718 
2719 //===----------------------------------------------------------------------===//
2720 // PadOp
2721 //===----------------------------------------------------------------------===//
2722 
2723 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2724  setNameFn(getResult(), "padded");
2725 }
2726 
2727 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2728 // supports optional types.
2729 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2730  Type typeToInfer, Type typeToInferFrom) {}
2731 
2732 ParseResult
2733 parseInferType(OpAsmParser &parser,
2734  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2735  Type &typeToInfer, Type typeToInferFrom) {
2736  if (optOperand)
2737  typeToInfer = typeToInferFrom;
2738  return success();
2739 }
2740 
2741 LogicalResult PadOp::verify() {
2742  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
2743  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
2744  auto expectedType =
2745  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2746  if (!expectedType) {
2747  return emitError("failed to infer expectedType from sourceType ")
2748  << sourceType << ", specified resultType is " << resultType;
2749  }
2750  if (resultType.getRank() != expectedType.getRank()) {
2751  return emitError("specified type ")
2752  << resultType << " does not match the inferred type "
2753  << expectedType;
2754  }
2755  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2756  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2757  continue;
2758  if (expectedType.isDynamicDim(i))
2759  continue;
2760  return emitError("specified type ")
2761  << resultType << " does not match the inferred type "
2762  << expectedType;
2763  }
2764 
2765  return success();
2766 }
2767 
2768 LogicalResult PadOp::verifyRegions() {
2769  auto &region = getRegion();
2770  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2771  Block &block = region.front();
2772  if (block.getNumArguments() != rank)
2773  return emitError("expected the block to have ") << rank << " arguments";
2774 
2775  // Note: the number and type of yield values are checked in the YieldOp.
2776  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2777  if (!en.value().isIndex())
2778  return emitOpError("expected block argument ")
2779  << (en.index() + 1) << " to be an index";
2780  }
2781 
2782  // Ensure that the region yields an element of the right type.
2783  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2784  if (yieldOp.getValue().getType() !=
2785  llvm::cast<ShapedType>(getType()).getElementType())
2786  return emitOpError("expected yield type to match shape element type");
2787 
2788  return success();
2789 }
2790 
2791 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2792  ArrayRef<int64_t> staticLow,
2793  ArrayRef<int64_t> staticHigh,
2794  ArrayRef<int64_t> resultShape) {
2795  unsigned rank = sourceType.getRank();
2796  if (staticLow.size() != rank)
2797  return RankedTensorType();
2798  if (staticHigh.size() != rank)
2799  return RankedTensorType();
2800  if (!resultShape.empty() && resultShape.size() != rank)
2801  return RankedTensorType();
2802 
2803  SmallVector<int64_t, 4> inferredShape;
2804  for (auto i : llvm::seq<unsigned>(0, rank)) {
2805  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2806  staticHigh[i] == ShapedType::kDynamic) {
2807  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2808  : resultShape[i]);
2809  } else {
2810  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2811  assert((resultShape.empty() || size == resultShape[i] ||
2812  resultShape[i] == ShapedType::kDynamic) &&
2813  "mismatch between inferred shape and result shape");
2814  inferredShape.push_back(size);
2815  }
2816  }
2817 
2818  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2819 }
2820 
2821 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2822  Value source, ArrayRef<int64_t> staticLow,
2823  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
2824  bool nofold, ArrayRef<NamedAttribute> attrs) {
2825  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2826  if (!resultType)
2827  resultType = inferResultType(sourceType, staticLow, staticHigh);
2828  build(b, result, resultType, source, low, high,
2829  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2830  nofold ? b.getUnitAttr() : UnitAttr());
2831  result.addAttributes(attrs);
2832 }
2833 
2834 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2835  Value source, ValueRange low, ValueRange high, bool nofold,
2836  ArrayRef<NamedAttribute> attrs) {
2837  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2838  unsigned rank = sourceType.getRank();
2839  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
2840  build(b, result, resultType, source, staticVector, staticVector, low, high,
2841  nofold, attrs);
2842 }
2843 
2844 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2845  Value source, ArrayRef<OpFoldResult> low,
2846  ArrayRef<OpFoldResult> high, bool nofold,
2847  ArrayRef<NamedAttribute> attrs) {
2848  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2849  SmallVector<Value, 4> dynamicLow, dynamicHigh;
2850  SmallVector<int64_t, 4> staticLow, staticHigh;
2851  // staticLow and staticHigh carry the full padding configuration. Each
2852  // padding entry grows staticLow and staticHigh by one value. If the entry
2853  // is dynamic (i.e., not a constant), dynamicLow and dynamicHigh grow by
2854  // one value as well.
2855  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
2856  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
2857  if (!resultType) {
2858  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
2859  }
2860  assert(llvm::isa<RankedTensorType>(resultType));
2861  build(b, result, resultType, source, dynamicLow, dynamicHigh,
2862  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2863  nofold ? b.getUnitAttr() : UnitAttr());
2864  result.addAttributes(attrs);
2865 }
2866 
2867 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2868  Value source, ArrayRef<OpFoldResult> low,
2869  ArrayRef<OpFoldResult> high, Value constantPadValue,
2870  bool nofold, ArrayRef<NamedAttribute> attrs) {
2871  build(b, result, resultType, source, low, high, nofold, attrs);
2872 
2873  // Add a region and a block to yield the pad value.
2874  Region *region = result.regions[0].get();
2875  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
2876  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
2877  SmallVector<Location> blockArgLocs(sourceRank, result.location);
2878 
2879  // `builder.createBlock` changes the insertion point within the block. Create
2880  // a guard to reset the insertion point of the builder after it is destroyed.
2881  OpBuilder::InsertionGuard guard(b);
2882  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
2883  b.create<tensor::YieldOp>(result.location, constantPadValue);
2884 }
2885 
2886 llvm::SmallBitVector PadOp::getPaddedDims() {
2887  llvm::SmallBitVector paddedDims(getSourceType().getRank());
2888  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
2889  for (const auto &en : enumerate(paddingWidths))
2890  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
2891  paddedDims.set(en.index());
2892  };
2893  extractPaddedDims(getMixedLowPad());
2894  extractPaddedDims(getMixedHighPad());
2895  return paddedDims;
2896 }
2897 
2898 namespace {
2899 // Folds tensor.pad when all low and high padding amounts are statically zero
2900 // and the `nofold` attribute is not set.
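//
// Illustrative example (a sketch, not taken from the original comments):
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] {
//     ^bb0(%i: index, %j: index):
//       tensor.yield %cst : f32
//   } : tensor<?x4xf32> to tensor<4x4xf32>
//
// is replaced by `tensor.cast %src : tensor<?x4xf32> to tensor<4x4xf32>`.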
2901 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
2902  using OpRewritePattern<PadOp>::OpRewritePattern;
2903 
2904  LogicalResult matchAndRewrite(PadOp padTensorOp,
2905  PatternRewriter &rewriter) const override {
2906  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
2907  return failure();
2908  if (padTensorOp.getNofold())
2909  return failure();
2910  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2911  padTensorOp, padTensorOp.getResult().getType(),
2912  padTensorOp.getSource());
2913  return success();
2914  }
2915 };
2916 
2917 // Fold CastOp into PadOp when adding static information.
2918 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
2919  using OpRewritePattern<PadOp>::OpRewritePattern;
2920 
2921  LogicalResult matchAndRewrite(PadOp padTensorOp,
2922  PatternRewriter &rewriter) const override {
2923  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
2924  if (!tensor::canFoldIntoConsumerOp(castOp))
2925  return failure();
2926 
2927  auto newResultType = PadOp::inferResultType(
2928  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
2929  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2930  padTensorOp.getResultType().getShape());
2931 
2932  if (newResultType == padTensorOp.getResultType()) {
2933  rewriter.modifyOpInPlace(padTensorOp, [&]() {
2934  padTensorOp.getSourceMutable().assign(castOp.getSource());
2935  });
2936  } else {
2937  auto newOp = rewriter.create<PadOp>(
2938  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
2939  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2940  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
2941  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2942  IRMapping mapper;
2943  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
2944 
2945  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2946  padTensorOp, padTensorOp.getResultType(), newOp);
2947  }
2948  return success();
2949  }
2950 };
2951 
2952 // Fold a CastOp that consumes the result of a PadOp back into the PadOp if
2953 // the cast adds static information.
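//
// Illustrative example (a sketch, not taken from the original comments):
//
//   %0 = tensor.pad %src low[0, %l] high[0, %h] { ... }
//       : tensor<8x?xf32> to tensor<8x?xf32>
//   %1 = tensor.cast %0 : tensor<8x?xf32> to tensor<8x16xf32>
//
// becomes a single tensor.pad that produces tensor<8x16xf32> directly.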
2954 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2955  using OpRewritePattern<PadOp>::OpRewritePattern;
2956 
2957  LogicalResult matchAndRewrite(PadOp padTensorOp,
2958  PatternRewriter &rewriter) const override {
2959  if (!padTensorOp.getResult().hasOneUse())
2960  return failure();
2961  auto tensorCastOp =
2962  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2963  if (!tensorCastOp)
2964  return failure();
2965  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
2966  tensorCastOp.getDest().getType()))
2967  return failure();
2968 
2969  auto replacementOp = rewriter.create<PadOp>(
2970  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2971  padTensorOp.getSource(), padTensorOp.getStaticLow(),
2972  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
2973  padTensorOp.getHigh(), padTensorOp.getNofold(),
2974  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2975  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2976 
2977  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
2978  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
2979  return success();
2980  }
2981 };
2982 
2983 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
2984 /// different dimensions. The pattern applies if the following preconditions
2985 /// hold:
2986 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
2987 /// 2) the tensor::ExtractSliceOps have only unit-strides,
2988 /// 3) the tensor::PadOps perform only high-padding,
2989 /// 4) the tensor::PadOps have the same constant padding value,
2990 /// 5) the tensor::PadOps do not have common padding dimensions,
2991 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
2992 /// zero-offset for every dimension.
2993 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
2994 ///    the padded source dimensions.
2996 ///
2997 /// Example:
2998 ///
2999 /// ```mlir
3000 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3001 /// : tensor<64x64xf32> to tensor<?x64xf32>
3002 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3003 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3004 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3005 /// : tensor<8x64xf32> to tensor<8x?xf32>
3006 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3007 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3008 /// ```
3009 ///
3010 /// folds into:
3011 ///
3012 /// ```mlir
3013 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3014 /// : tensor<64x64xf32> to tensor<?x?xf32>
3015 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3016 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3017 /// ```
3018 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3019  using OpRewritePattern<PadOp>::OpRewritePattern;
3020 
3021  LogicalResult matchAndRewrite(PadOp padOp,
3022  PatternRewriter &rewriter) const override {
3023  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3024  if (!innerSliceOp)
3025  return failure();
3026  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3027  if (!outerPadOp || outerPadOp.getNofold())
3028  return failure();
3029  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3030  if (!outerSliceOp)
3031  return failure();
3032 
3033  // 1) Fail if the chain is rank-reducing.
3034  int64_t rank = padOp.getSourceType().getRank();
3035  if (outerSliceOp.getSourceType().getRank() != rank) {
3036  return rewriter.notifyMatchFailure(padOp,
3037  "cannot fold rank-reducing chain");
3038  }
3039 
3040  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3041  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3042  return rewriter.notifyMatchFailure(
3043  padOp, "cannot fold non-unit stride ExtractSliceOps");
3044  }
3045 
3046  // 3) Fail if the tensor::PadOps have non-zero low padding.
3047  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3048  return rewriter.notifyMatchFailure(padOp,
3049  "cannot fold PadOps with low padding");
3050  }
3051 
3052  // 4) Fail if the tensor::PadOps padding values do not match.
3053  Attribute innerAttr, outerAttr;
3054  Value innerValue = padOp.getConstantPaddingValue();
3055  Value outerValue = outerPadOp.getConstantPaddingValue();
3056  if (!innerValue || !outerValue ||
3057  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3058  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3059  innerAttr != outerAttr) {
3060  return rewriter.notifyMatchFailure(
3061  padOp, "cannot fold PadOps with different padding values");
3062  }
3063 
3064  // 5) Fail if a dimension is padded by both tensor::PadOps.
3065  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3066  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3067  if (innerDims.anyCommon(outerDims)) {
3068  return rewriter.notifyMatchFailure(
3069  padOp, "cannot fold PadOps with common padding dimensions");
3070  }
3071 
3072  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3073  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3074  // for every dimension, and use the offset of the other pair. Fail if no
3075  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3076  // exists.
3077  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3078  for (auto en : enumerate(newOffsets)) {
3079  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3080  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3081  if (!innerDims.test(en.index()) &&
3082  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3083  en.value() = outerOffset;
3084  continue;
3085  }
3086  if (!outerDims.test(en.index()) &&
3087  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3088  en.value() = innerOffset;
3089  continue;
3090  }
3091  return rewriter.notifyMatchFailure(
3092  padOp, "cannot find zero-offset and zero-padding pair");
3093  }
3094 
3095  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3096  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3097  // outer tensor::PadOp and fail if the size of the inner
3098  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3099  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3100  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3101  for (auto en : enumerate(newSizes)) {
3102  if (!outerDims.test(en.index()))
3103  continue;
3104  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3105  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3106  assert(!ShapedType::isDynamic(sourceSize) &&
3107  "expected padded dimension to have a static size");
3108  if (getConstantIntValue(sliceSize) != sourceSize) {
3109  return rewriter.notifyMatchFailure(
3110  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3111  "match the size of the outer padding");
3112  }
3113  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3114  }
3115 
3116  // Combine the high paddings of the two tensor::PadOps.
3117  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3118  for (auto en : enumerate(newHighPad)) {
3119  if (innerDims.test(en.index()))
3120  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3121  if (outerDims.test(en.index()))
3122  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3123  }
3124 
3125  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3126  // the two paddings in one step.
3127  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3128  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3129  innerSliceOp.getMixedStrides());
3130  auto newPadOp = rewriter.create<PadOp>(
3131  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3132  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3133  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3134  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3135  newPadOp.getRegion().begin());
3136  rewriter.replaceOp(padOp, newPadOp.getResult());
3137  return success();
3138  }
3139 };
3140 
3141 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3142  using OpRewritePattern<PadOp>::OpRewritePattern;
3143 
3144  LogicalResult matchAndRewrite(PadOp padTensorOp,
3145  PatternRewriter &rewriter) const override {
3146  Value input = padTensorOp.getSource();
3147  if (!llvm::isa<RankedTensorType>(input.getType()))
3148  return failure();
3149  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3150  auto inputRank = inputDims.size();
3151 
3152  auto oldResultType =
3153  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3154  if (!oldResultType)
3155  return failure();
3156 
3157  auto outputDims = oldResultType.getShape();
3158 
3159  // Extract the static info from the high and low operands.
3160  SmallVector<int64_t> constOperandsLow;
3161  SmallVector<Value> newLows;
3162  for (auto operand : padTensorOp.getLow()) {
3163  APSInt intOp;
3164  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3165  constOperandsLow.push_back(ShapedType::kDynamic);
3166  newLows.push_back(operand);
3167  continue;
3168  }
3169  constOperandsLow.push_back(intOp.getExtValue());
3170  }
3171  SmallVector<int64_t> constOperandsHigh;
3172  SmallVector<Value> newHighs;
3173  for (auto operand : padTensorOp.getHigh()) {
3174  APSInt intOp;
3175  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3176  constOperandsHigh.push_back(ShapedType::kDynamic);
3177  newHighs.push_back(operand);
3178  continue;
3179  }
3180  constOperandsHigh.push_back(intOp.getExtValue());
3181  }
3182 
3183  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3184  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3185 
3186  // Verify the op is well-formed.
3187  if (inputDims.size() != outputDims.size() ||
3188  inputDims.size() != constLow.size() ||
3189  inputDims.size() != constHigh.size())
3190  return failure();
3191 
3192  auto lowCount = 0;
3193  auto highCount = 0;
3194  for (size_t i = 0; i < inputRank; i++) {
3195  if (constLow[i] == ShapedType::kDynamic)
3196  constLow[i] = constOperandsLow[lowCount++];
3197  if (constHigh[i] == ShapedType::kDynamic)
3198  constHigh[i] = constOperandsHigh[highCount++];
3199  }
3200 
3201  auto staticLow = ArrayRef<int64_t>(constLow);
3202  auto staticHigh = ArrayRef<int64_t>(constHigh);
3203 
3204  // Calculate the output sizes with the static information.
3205  SmallVector<int64_t> newOutDims;
3206  for (size_t i = 0; i < inputRank; i++) {
3207  if (outputDims[i] == ShapedType::kDynamic) {
3208  newOutDims.push_back(
3209  (staticLow[i] == ShapedType::kDynamic ||
3210  staticHigh[i] == ShapedType::kDynamic ||
3211  inputDims[i] == ShapedType::kDynamic
3212  ? ShapedType::kDynamic
3213  : inputDims[i] + staticLow[i] + staticHigh[i]));
3214  } else {
3215  newOutDims.push_back(outputDims[i]);
3216  }
3217  }
3218 
3219  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3220  llvm::all_of(newOutDims,
3221  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3222  return failure();
3223 
3224  // Rewrite the op using the new static type.
3225  auto newResultType = RankedTensorType::get(
3226  newOutDims, padTensorOp.getType().getElementType());
3227  auto newOp = rewriter.create<PadOp>(
3228  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3229  newLows, newHighs, padTensorOp.getNofold(),
3230  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3231 
3232  IRMapping mapper;
3233  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3234  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3235  newOp);
3236 
3237  return success();
3238  }
3239 };
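// Illustrative example (SSA names are hypothetical): when the dynamic low and
// high operands are in fact constants, e.g.
//
//   %c2 = arith.constant 2 : index
//   %0 = tensor.pad %src low[%c2, 0] high[%c2, 0] { ...
//   } : tensor<4x4xf32> to tensor<?x4xf32>
//
// the pattern recomputes the static result type tensor<8x4xf32> and casts the
// new pad back to the original tensor<?x4xf32>.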
3240 
3241 } // namespace
3242 
3243 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3244  MLIRContext *context) {
3245  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3246  FoldOrthogonalPaddings, FoldStaticPadding>(context);
3247 }
3248 
3249 /// Return the padding value of the PadOp if it is constant. In this context,
3250 /// "constant" means an actual constant or "defined outside of the block".
3251 ///
3252 /// Values are considered constant in three cases:
3253 /// - A ConstantLike value.
3254 /// - A basic block argument from a different block.
3255 /// - A value defined outside of the block.
3256 ///
3257 /// If the padding value is not constant, an empty Value is returned.
3258 Value PadOp::getConstantPaddingValue() {
3259  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3260  if (!yieldOp)
3261  return {};
3262  Value padValue = yieldOp.getValue();
3263  // Check if yield value is a constant.
3264  if (matchPattern(padValue, m_Constant()))
3265  return padValue;
3266  // Check if yield value is defined inside the PadOp block.
3267  if (padValue.getParentBlock() == &getRegion().front())
3268  return {};
3269  // Else: Yield value defined outside of the PadOp block.
3270  return padValue;
3271 }
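// Illustrative example (SSA names are hypothetical): in
//
//   %cst = arith.constant 0.0 : f32
//   %0 = tensor.pad %src low[0, 1] high[1, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<6x6xf32> to tensor<7x7xf32>
//
// getConstantPaddingValue() returns %cst because the yielded value is defined
// outside of the pad body; a value computed inside the body from %i or %j
// would make it return an empty Value instead.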
3272 
3273 OpFoldResult PadOp::fold(FoldAdaptor) {
3274  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3275  !getNofold())
3276  return getSource();
3277  return {};
3278 }
3279 
3280 //===----------------------------------------------------------------------===//
3281 // ParallelInsertSliceOp
3282 //===----------------------------------------------------------------------===//
3283 
3284 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3285  ParallelCombiningOpInterface parallelCombiningParent =
3286  getParallelCombiningParent();
3287  for (const auto &it :
3288  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3289  Operation &nextOp = it.value();
3290  if (&nextOp == getOperation())
3291  return parallelCombiningParent.getParentResult(it.index());
3292  }
3293  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3294 }
3295 
3296 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3297 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3298  Value source, Value dest,
3299  ArrayRef<OpFoldResult> offsets,
3300  ArrayRef<OpFoldResult> sizes,
3301  ArrayRef<OpFoldResult> strides,
3302  ArrayRef<NamedAttribute> attrs) {
3303  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3304  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3305  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3306  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3307  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3308  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3309  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3310  b.getDenseI64ArrayAttr(staticSizes),
3311  b.getDenseI64ArrayAttr(staticStrides));
3312  result.addAttributes(attrs);
3313 }
3314 
3315 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3316 /// packed into a Range vector.
3317 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3318  Value source, Value dest,
3319  ArrayRef<Range> ranges,
3320  ArrayRef<NamedAttribute> attrs) {
3321  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3322  build(b, result, source, dest, offsets, sizes, strides, attrs);
3323 }
3324 
3325 // Build a ParallelInsertSliceOp with dynamic entries.
3326 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3327  Value source, Value dest, ValueRange offsets,
3328  ValueRange sizes, ValueRange strides,
3329  ArrayRef<NamedAttribute> attrs) {
3330  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3331  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3332  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3333  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3334  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3335  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3336  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3337 }
3338 
3339 LogicalResult ParallelInsertSliceOp::verify() {
3340  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3341  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3342  << *(getOperation()->getParentOp());
3343 
3344  RankedTensorType expectedType;
3345  SliceVerificationResult result =
3346  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3347  getStaticSizes(), getStaticStrides(), &expectedType);
3348  return produceSliceErrorMsg(result, *this, expectedType);
3349 }
3350 
3351 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3352  RewritePatternSet &results, MLIRContext *context) {
3353  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3354  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3355  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3356 }
3357 
3358 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3359  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3360 }
3361 
3362 //===----------------------------------------------------------------------===//
3363 // ScatterOp
3364 //===----------------------------------------------------------------------===//
3365 
3366 void ScatterOp::getAsmResultNames(
3367  function_ref<void(Value, StringRef)> setNameFn) {
3368  setNameFn(getResult(), "scatter");
3369 }
3370 
3371 LogicalResult ScatterOp::verify() {
3372  int64_t destRank = getDestType().getRank();
3373  ArrayRef<int64_t> scatterDims = getScatterDims();
3374  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3375  "scatter", "dest")))
3376  return failure();
3377 
3378  if (!getUnique())
3379  return emitOpError("requires 'unique' attribute to be set");
3380  // TODO: we could also check statically that there are fewer leading index
3381  // tensor dims than the dest dims. If this is not the case, the unique
3382  // attribute cannot be true.
3383 
3384  // Use the GatherOp::inferResultType on the `dest` type and verify the
3385  // expected type matches the source type.
3386  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3387  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3388  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3389  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3390  if (getSourceType() != expectedSourceType &&
3391  getSourceType() != expectedRankReducedSourceType) {
3392  return emitOpError("source type "
3393  "mismatch: "
3394  "expected ")
3395  << expectedSourceType << " or its rank-reduced variant "
3396  << expectedRankReducedSourceType << " (got: " << getSourceType()
3397  << ")";
3398  }
3399 
3400  return success();
3401 }
3402 
3403 //===----------------------------------------------------------------------===//
3404 // SplatOp
3405 //===----------------------------------------------------------------------===//
3406 
3407 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3408  Type aggregateType, ValueRange dynamicSizes) {
3409  build(builder, result, aggregateType, element, dynamicSizes);
3410 }
3411 
3412 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3413  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3414  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3415  build(builder, result, aggregateType, element, dynamicSizes);
3416 }
3417 
3418 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3419  ArrayRef<OpFoldResult> sizes) {
3420  SmallVector<int64_t> staticShape;
3421  SmallVector<Value> dynamicSizes;
3422  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3423  build(builder, result, element, staticShape, dynamicSizes);
3424 }
3425 
3426 void SplatOp::getAsmResultNames(
3427  function_ref<void(Value, StringRef)> setNameFn) {
3428  setNameFn(getResult(), "splat");
3429 }
3430 
3431 LogicalResult SplatOp::verify() {
3432  if (getType().getNumDynamicDims() !=
3433  static_cast<int64_t>(getDynamicSizes().size()))
3434  return emitOpError("incorrect number of dynamic sizes, has ")
3435  << getDynamicSizes().size() << ", expected "
3436  << getType().getNumDynamicDims();
3437  return success();
3438 }
3439 
3440 LogicalResult
3441 SplatOp::reifyResultShapes(OpBuilder &builder,
3442  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3443  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3444  unsigned ctr = 0;
3445  for (int64_t i = 0; i < getType().getRank(); ++i) {
3446  if (getType().isDynamicDim(i)) {
3447  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3448  } else {
3449  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3450  }
3451  }
3452  return success();
3453 }
3454 
3455 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3456  auto constOperand = adaptor.getInput();
3457  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
3458  return {};
3459 
3460  // Do not fold if the splat is not statically shaped
3461  if (!getType().hasStaticShape())
3462  return {};
3463 
3464  // SplatElementsAttr::get treats single value for second arg as being a
3465  // splat.
3466  return SplatElementsAttr::get(getType(), {constOperand});
3467 }
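// Illustrative example: with a constant scalar operand and a statically shaped
// result,
//
//   %cst = arith.constant 1.0 : f32
//   %0 = tensor.splat %cst : tensor<2x3xf32>
//
// folds to the attribute dense<1.000000e+00> : tensor<2x3xf32>.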
3468 
3469 //===----------------------------------------------------------------------===//
3470 // PackOp/UnPackOp Common
3471 //===----------------------------------------------------------------------===//
3472 
3473 template <typename OpTy>
3474 static LogicalResult
3475 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3476  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3477  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3478  "applies to only pack or unpack operations");
3479  int64_t destRank = op.getDestRank();
3480  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3481  reifiedReturnShapes[0] =
3482  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3483  return success();
3484 }
3485 
3486 template <typename OpTy>
3487 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3488  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3489  "applies to only pack or unpack operations");
3490  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3491  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3492  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3493  assert(tiles.size() == dimsToTile.size() &&
3494  "tiles must match indices of dimension to block");
3495  // bind the dimension `i` with the tile factor.
3496  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3497  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3498  return dimAndTileMapping;
3499 }
3500 
3501 template <typename OpTy>
3502 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3503  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3504  "applies to only pack or unpack operations");
3505  Builder builder(op);
3506  SmallVector<OpFoldResult> mixedInnerTiles;
3507  unsigned dynamicValIndex = 0;
3508  for (int64_t staticTile : op.getStaticInnerTiles()) {
3509  if (!ShapedType::isDynamic(staticTile))
3510  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3511  else
3512  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3513  }
3514  return mixedInnerTiles;
3515 }
3516 
3517 template <typename OpTy>
3518 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3519  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3520  "applies to only pack or unpack operations");
3521  SmallVector<Value> dynamicTiles;
3522  SmallVector<int64_t> staticTiles;
3523  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3524  return staticTiles;
3525 }
3526 
3527 /// Returns true if `dimsPos` is invalid. It is invalid when:
3528 /// a) it contains duplicates;
3529 /// b) at least one dimension is out of bounds (a valid `dimPos` is >= 0 and < `rank`);
3530 /// c) the number of elements in `dimsPos` is greater than `rank`.
3531 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3532  size_t rank) {
3533  size_t dimsPosSize = dimsPos.size();
3534  if (dimsPosSize > rank)
3535  return true;
3536  DenseSet<int64_t> uniqued;
3537  for (int64_t dim : dimsPos)
3538  uniqued.insert(dim);
3539  if (dimsPosSize != uniqued.size())
3540  return true;
3541  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3542  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3543  });
3544 }
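// Illustrative example: with rank == 3,
//   [0, 2]       is valid,
//   [0, 0]       is invalid (duplicate entry),
//   [1, 3]       is invalid (3 is out of bounds for rank 3),
//   [0, 1, 2, 3] is invalid (more entries than the rank).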
3545 
3546 /// Returns true if every dimension of `sourceShape` is less than or equal to
3547 /// the corresponding dimension of `limitShape` (dynamic extents always pass).
3548 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3549  ArrayRef<int64_t> limitShape) {
3550  assert(
3551  sourceShape.size() == limitShape.size() &&
3552  "expected source shape and limit shape to have the same rank");
3553  return llvm::all_of(
3554  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3555  int64_t sourceExtent = std::get<0>(it);
3556  int64_t limit = std::get<1>(it);
3557  return ShapedType::isDynamic(sourceExtent) ||
3558  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3559  });
3560 }
3561 
3562 template <typename OpTy>
3563 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3564  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3565  "applies to only pack or unpack operations");
3566  Operation *op = packOrUnPack.getOperation();
3567 
3568  // Return true if we have a zero-value tile.
3569  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3570  return llvm::any_of(
3571  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3572  };
3573 
3574  // Verify tiles. Do not allow zero tiles.
3575  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3576  if (hasZeros(mixedTiles))
3577  return op->emitError("invalid zero tile factor");
3578 
3579  // Verify inner_dims_pos and outer_dims_perm.
3580  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3581  ? packOrUnPack.getSourceType()
3582  : packOrUnPack.getDestType();
3583  size_t unpackedRank = unpackedType.getRank();
3584  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3585  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3586  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3587  return op->emitError("invalid inner_dims_pos vector");
3588  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3589  return op->emitError("invalid outer_dims_perm vector");
3590  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3591  return op->emitError("outer_dims_perm must be a permutation or empty");
3592 
3593  // Tiling factors must be less than or equal to the input rank for pack (or
3594  // output rank for unpack), and must match the number of `inner_dims_pos`.
3595  if (mixedTiles.size() > unpackedRank) {
3596  return op->emitError("tiling factors must be less than or equal to the "
3597  "input rank for pack or output rank for unpack");
3598  }
3599  if (mixedTiles.size() != innerDimsPos.size()) {
3600  return op->emitError(
3601  "tiling factors must equal the number of dimensions to tile");
3602  }
3603 
3604  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3605  ? packOrUnPack.getDestType()
3606  : packOrUnPack.getSourceType();
3607  size_t packedRank = packedType.getRank();
3608  // Require output rank to match input rank + number of blocking factors.
3609  if (unpackedRank + mixedTiles.size() != packedRank) {
3610  return op->emitError(
3611  "packed rank must equal unpacked rank + tiling factors");
3612  }
3613 
3614  // Verify result shape is greater than the minimum expected
3615  // by the pack operation, and that the output shape
3616  // represents full tiles.
3617  RankedTensorType expectedPackedType = PackOp::inferPackedType(
3618  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3619  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3620  return op->emitError("the shape of output is not large enough to hold the "
3621  "packed data. Expected at least ")
3622  << expectedPackedType << ", got " << packedType;
3623  }
3624  if (!llvm::all_of(
3625  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3626  mixedTiles),
3627  [](std::tuple<int64_t, OpFoldResult> it) {
3628  std::optional<int64_t> constTileSize =
3629  getConstantIntValue(std::get<1>(it));
3630  int64_t shape = std::get<0>(it);
3631  if (!constTileSize) {
3632  // If specified tile size is dynamic, output shape should
3633  // be dynamic too.
3634  return ShapedType::isDynamic(shape);
3635  }
3636  if (ShapedType::isDynamic(shape)) {
3637  // For the shape being dynamic when tile size is
3638  // specified, return true. In canonical form a constant
3639  // tile size should lead to constant shape of the tiled
3640  // dimension, but not needed for verification.
3641  return true;
3642  }
3643  return shape == constTileSize.value();
3644  })) {
3645  return op->emitError("mismatch in inner tile sizes specified and the shape "
3646  "of the tiled dimension in the packed type");
3647  }
3648  return success();
3649 }
3650 
3651 namespace {
3652 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
3653 /// various permutations to the op.
3654 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
3655 // these. These may or may not become true foldings / canonicalizations
3656 // depending on how aggressive we want to be in automatically folding
3657 // transposes.
3658 struct PackOrUnPackTransposeResult {
3659  SmallVector<int64_t> innerDimsPos;
3660  SmallVector<OpFoldResult> innerTiles;
3661  SmallVector<int64_t> outerDimsPerm;
3662 };
3663 } // namespace
3664 
3665 template <typename OpTy>
3666 static PackOrUnPackTransposeResult
3668  ArrayRef<int64_t> innerPermutation,
3669  ArrayRef<int64_t> outerPermutation) {
3670  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3671  "applies to only pack or unpack operations");
3672  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3673  "some permutation must be non-empty");
3674  PackOrUnPackTransposeResult metadata;
3675  metadata.innerDimsPos =
3676  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
3677  metadata.innerTiles =
3678  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
3679  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3680  ? packOrUnPackOp.getSourceRank()
3681  : packOrUnPackOp.getDestRank();
3682  metadata.outerDimsPerm =
3683  packOrUnPackOp.getOuterDimsPerm().empty()
3684  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3685  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
3686  if (!innerPermutation.empty()) {
3687  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3688  isPermutationVector(innerPermutation) &&
3689  "invalid inner permutation");
3690  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
3691  applyPermutationToVector(metadata.innerTiles, innerPermutation);
3692  }
3693  if (!outerPermutation.empty()) {
3694  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3695  isPermutationVector(outerPermutation) &&
3696  "invalid outer permutation");
3697  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
3698  }
3699  return metadata;
3700 }
3701 
3702 //===----------------------------------------------------------------------===//
3703 // PackOp
3704 //===----------------------------------------------------------------------===//
3705 
3706 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3707  setNameFn(getResult(), "pack");
3708 }
3709 
3710 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3711  Value dest, ArrayRef<int64_t> innerDimsPos,
3712  ArrayRef<OpFoldResult> innerTiles,
3713  std::optional<Value> paddingValue,
3714  ArrayRef<int64_t> outerDimsPerm) {
3715  assert(innerDimsPos.size() == innerTiles.size() &&
3716  "number of tile sizes specified must match the specified number of "
3717  "original dimensions to be tiled");
3718  SmallVector<int64_t> staticTileSizes;
3719  SmallVector<Value> dynamicTileSizes;
3720  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3721  build(builder, state, dest.getType(), source, dest,
3722  paddingValue ? *paddingValue : nullptr,
3723  outerDimsPerm.empty() ? nullptr
3724  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3725  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3726  builder.getDenseI64ArrayAttr(staticTileSizes));
3727 }
3728 
3729 LogicalResult
3730 PackOp::reifyResultShapes(OpBuilder &builder,
3731  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3732  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3733 }
3734 
3735 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3736  return getDimAndTileMappingImpl(*this);
3737 }
3738 
3739 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3740  return getMixedTilesImpl(*this);
3741 }
3742 
3743 SmallVector<int64_t> PackOp::getStaticTiles() {
3744  return getStaticTilesImpl(*this);
3745 }
3746 
3747 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
3748  ArrayRef<int64_t> innerDimsPos,
3749  ArrayRef<int64_t> outputShape,
3750  ArrayRef<int64_t> outerDimsPerm,
3751  ArrayRef<OpFoldResult> innerTiles) {
3752  SmallVector<int64_t> outputTileSizes(
3753  outputShape.take_front(inputShape.size()));
3754  if (!outerDimsPerm.empty()) {
3755  assert(outerDimsPerm.size() == outputTileSizes.size() &&
3756  "expected output and outer_dims_perm to have same size");
3757  applyPermutationToVector(outputTileSizes,
3758  invertPermutationVector(outerDimsPerm));
3759  }
3760  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3761  if (ShapedType::isDynamic(inputShape[pos]))
3762  continue;
3763  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3764 
3765  if (!constantTile) {
3766  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3767  (inputShape[pos] % outputTileSizes[pos] != 0))
3768  return true;
3769  } else if (inputShape[pos] % (*constantTile) != 0) {
3770  return true;
3771  }
3772  }
3773  return false;
3774 }
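// Illustrative example (SSA names are hypothetical): packing tensor<13x15xf32>
// with inner_tiles = [8, 2] on dims [0, 1] produces partial tiles
// (13 % 8 != 0 and 15 % 2 != 0), so a padding value is required:
//
//   %0 = tensor.pack %src padding_value(%pad : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %dest
//       : tensor<13x15xf32> -> tensor<2x8x8x2xf32>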
3775 
3776 LogicalResult PackOp::verify() {
3777  if (failed(commonVerifierPackAndUnPackOp(*this)))
3778  return failure();
3779 
3780  // Verify padding value, and bail out if the tile does not divide the
3781  // dimension fully. In the case of dynamic tile factors or dimensions, having
3782  // a partial tile is undefined behavior.
3783  auto paddingValue = getPaddingValue();
3784  if (paddingValue &&
3785  paddingValue.getType() != getSourceType().getElementType()) {
3786  return emitOpError("expected padding_value has ")
3787  << getSourceType().getElementType()
3788  << " but got: " << paddingValue.getType();
3789  }
3790 
3791  if (!paddingValue &&
3792  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3793  getDestType().getShape(), getOuterDimsPerm(),
3794  getMixedTiles())) {
3795  return emitOpError(
3796  "invalid tile factor or output size provided. Only full tiles are "
3797  "supported when padding_value is not set");
3798  }
3799  return success();
3800 }
3801 
3802 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
3803 /// Value's to kDynamic, even if they are arith.constant values.
3804 static SmallVector<int64_t>
3805 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
3806  SmallVector<int64_t> result;
3807  for (auto o : ofrs) {
3808  // Have to do this first, as getConstantIntValue special-cases constants.
3809  if (llvm::dyn_cast_if_present<Value>(o))
3810  result.push_back(ShapedType::kDynamic);
3811  else
3812  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
3813  }
3814  return result;
3815 }
3816 
3817 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
3818 /// the packed type. Having a shared helper helps implement these two methods in
3819 /// a way that ensures that they agree on which dimensions are dynamic.
3820 static SmallVector<int64_t> getPackOpResultTypeShape(
3821  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
3822  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
3823  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
3824  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3825  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3826  continue;
3827  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3828  resultShape[tiledDim.value()] = ShapedType::kDynamic;
3829  continue;
3830  }
3831  resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
3832  innerTileSizes[tiledDim.index()]);
3833  }
3834 
3835  // Swap tile loops if outer_dims_perm is available.
3836  if (!outerDimsPerm.empty())
3837  applyPermutationToVector(resultShape, outerDimsPerm);
3838 
3839  // Append the inner tile dimensions.
3840  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
3841  return resultShape;
3842 }
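// Illustrative example: sourceShape = [20, 16], innerDimsPos = [0, 1],
// innerTileSizes = [4, 2], outerDimsPerm = [1, 0] gives
//   ceil-divide the tiled dims:  [ceil(20/4), ceil(16/2)] = [5, 8]
//   apply the outer permutation: [8, 5]
//   append the inner tiles:      [8, 5, 4, 2]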
3843 
3844 SmallVector<OpFoldResult> PackOp::getResultShape(
3845  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
3846  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
3847  ArrayRef<int64_t> outerDimsPerm) {
3848  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
3849 
3850  AffineExpr s0, s1;
3851  bindSymbols(builder.getContext(), s0, s1);
3852  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
3853  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3854  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
3855  builder, loc, ceilDivExpr,
3856  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
3857  }
3858  if (!outerDimsPerm.empty())
3859  applyPermutationToVector(resultDims, outerDimsPerm);
3860  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
3861 
3862  SmallVector<int64_t> resultTypeShape =
3863  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
3864  asShapeWithAnyValueAsDynamic(innerTileSizes),
3865  innerDimsPos, outerDimsPerm);
3866 
3867  // Fix-up `resultDims` to ensure that they are Value's if and only if the
3868  // result type shape says it's a dynamic dim. This is needed as callers may
3869  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
3870  // dynamic dims returned by that.
3871  for (unsigned i = 0; i < resultDims.size(); ++i) {
3872  if (!ShapedType::isDynamic(resultTypeShape[i]))
3873  continue;
3874  resultDims[i] =
3875  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
3876  }
3877 
3878  return resultDims;
3879 }
3880 
3881 /// Get the expected packed type based on source type, tile factors, position of
3882 /// the inner tiles and permutation of the outer tiled loop.
3883 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
3884  ArrayRef<int64_t> innerTileSizes,
3885  ArrayRef<int64_t> innerDimsPos,
3886  ArrayRef<int64_t> outerDimsPerm) {
3887  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
3888  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
3889  return RankedTensorType::get(resultShape, sourceType.getElementType());
3890 }
3891 
3892 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
3893  ArrayRef<OpFoldResult> innerTileSizes,
3894  ArrayRef<int64_t> innerDimsPos,
3895  ArrayRef<int64_t> outerDimsPerm) {
3896  AffineExpr dim0, dim1;
3897  bindDims(b.getContext(), dim0, dim1);
3898  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
3899  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
3900  {v1, v2});
3901  };
3902 
3903  SmallVector<OpFoldResult> mixedSizes;
3904  for (auto [index, value] : llvm::enumerate(
3905  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
3906  if (ShapedType::isDynamic(value))
3907  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
3908  else
3909  mixedSizes.push_back(b.getIndexAttr(value));
3910  }
3911  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
3912  int64_t dimPos = std::get<0>(it);
3913  OpFoldResult tileSize = std::get<1>(it);
3914  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
3915  }
3916  if (!outerDimsPerm.empty())
3917  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
3918 
3919  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
3920  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3921  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
3922 }
3923 
3924 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
3925  ArrayRef<int64_t> innerPermutation,
3926  ArrayRef<int64_t> outerPermutation) {
3927  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
3928  *this, innerPermutation, outerPermutation);
3929  Value transposedDest =
3930  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
3931  metadata.innerDimsPos, metadata.outerDimsPerm);
3932  return b.create<PackOp>(loc, getSource(), transposedDest,
3933  metadata.innerDimsPos, metadata.innerTiles,
3934  getPaddingValue(), metadata.outerDimsPerm);
3935 }
3936 
3937 /// Returns true if the tiles and the tiled dims are constant.
3938 template <typename OpTy>
3939 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
3940  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3941  "applies to only pack or unpack operations");
3942  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3943  ? op.getDestType()
3944  : op.getSourceType();
3945  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
3946  for (auto [dimDest, tile] : llvm::zip(
3947  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
3948  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
3949  if (!constTileSize || ShapedType::isDynamic(dimDest))
3950  return false;
3951  }
3952  return true;
3953 }
3954 
3955 Speculation::Speculatability PackOp::getSpeculatability() {
3956  if (getPaddingValue())
3957  return Speculation::Speculatable;
3958 
3959  // The verifier already rejects operations if we can statically prove that
3960  // the tile sizes do not evenly divide the dimension; thus, only check that
3961  // the tiles and the tiled inner dimensions are constant.
3962  if (!areTilesAndTiledDimsAllConstant(*this))
3963  return Speculation::NotSpeculatable;
3964 
3965  return Speculation::Speculatable;
3966 }
3967 
3968 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
3969 // dimensions for pack and unpack.
3970 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
3971  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3972  return false;
3973  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
3974 }
3975 
3976 // Return true if pack and unpack have the same tiles.
3977 // Same SSA values or same integer constants.
3978 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
3979  auto packTiles = packOp.getMixedTiles();
3980  auto unPackTiles = unPackOp.getMixedTiles();
3981  if (packTiles.size() != unPackTiles.size())
3982  return false;
3983  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
3984  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
3985  return false;
3986  }
3987  return true;
3988 }
3989 
3990 /// Returns true if the pack op does not need a padding value.
3991 static bool paddingIsNotNeeded(PackOp op) {
3992  auto srcType = op.getSourceType();
3993  if (llvm::any_of(op.getInnerDimsPos(),
3994  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
3995  return false;
3996  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
3997  return false;
3998  return !PackOp::requirePaddingValue(
3999  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4000  op.getOuterDimsPerm(), op.getMixedTiles());
4001 }
4002 
4003 /// Returns true if the `srcShape` or `destShape` is different from the one in
4004 /// `packOp` and populates each with the inferred static shape.
4005 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4006  SmallVectorImpl<int64_t> &destShape) {
4007  bool changeNeeded = false;
4008  srcShape.assign(packOp.getSourceType().getShape().begin(),
4009  packOp.getSourceType().getShape().end());
4010  destShape.assign(packOp.getDestType().getShape().begin(),
4011  packOp.getDestType().getShape().end());
4012  llvm::SmallSetVector<int64_t, 4> innerDims;
4013  innerDims.insert(packOp.getInnerDimsPos().begin(),
4014  packOp.getInnerDimsPos().end());
4015  auto outerDimsPerm = packOp.getOuterDimsPerm();
4016  int srcRank = packOp.getSourceRank();
4017  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4018  if (innerDims.contains(i))
4019  continue;
4020  int64_t srcPos = i;
4021  int64_t destPos = i;
4022  if (!outerDimsPerm.empty())
4023  destPos = outerDimsPerm[srcPos];
4024  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4025  ShapedType::isDynamic(destShape[destPos])) {
4026  continue;
4027  }
4028  int64_t size = srcShape[srcPos];
4029  if (ShapedType::isDynamic(size))
4030  size = destShape[destPos];
4031  srcShape[srcPos] = size;
4032  destShape[destPos] = size;
4033  changeNeeded = true;
4034  }
4035  return changeNeeded;
4036 }
4037 
4038 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4039  // Fold pack(unpack(x)) to x.
4040  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4041  if (unPackOp.getSourceType() != packOp.getDestType())
4042  return failure();
4043  if (packOp.getPaddingValue() ||
4044  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4045  !haveSameTiles(packOp, unPackOp))
4046  return failure();
4047  rewriter.replaceOp(packOp, unPackOp.getSource());
4048  return success();
4049  }
4050 
4051  // Fold optional PaddingValue operand away if padding is not needed.
4052  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4053  rewriter.startOpModification(packOp);
4054  packOp.getPaddingValueMutable().clear();
4055  rewriter.finalizeOpModification(packOp);
4056  return success();
4057  }
4058 
4059  // Insert tensor.cast ops if static shape inference is available.
4060  SmallVector<int64_t> srcShape, destShape;
4061  if (inferStaticShape(packOp, srcShape, destShape)) {
4062  Location loc = packOp.getLoc();
4063  Value source = packOp.getSource();
4064  if (srcShape != packOp.getSourceType().getShape()) {
4065  auto newSrcType = packOp.getSourceType().clone(srcShape);
4066  source =
4067  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4068  }
4069  Value dest = packOp.getDest();
4070  if (destShape != packOp.getDestType().getShape()) {
4071  auto newDestType = packOp.getDestType().clone(destShape);
4072  dest =
4073  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4074  }
4075  Value newOp = rewriter.create<tensor::PackOp>(
4076  loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4077  packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4078  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4079  packOp, packOp.getResult().getType(), newOp);
4080  return success();
4081  }
4082 
4083  return failure();
4084 }
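// Illustrative example (SSA names are hypothetical): a pack of an unpack with
// matching types, dims attributes, and tiles, and without a padding value,
//
//   %0 = tensor.unpack %a inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//       into %t : tensor<2x8x8x2xf32> -> tensor<16x16xf32>
//   %1 = tensor.pack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//       into %d : tensor<16x16xf32> -> tensor<2x8x8x2xf32>
//
// canonicalizes %1 to %a.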
4085 
4086 template <typename PackOrUnpackOp>
4087 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4088  RankedTensorType packedTensorType) {
4089  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4090  std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4091  "Function meant for pack/unpack");
4092  // This is a pad if packing only adds ones and we don't transpose dimensions.
4093 
4094  // Check that we are not transposing any dimensions.
4095  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4096  int64_t numPackedDims = innerDimsPos.size();
4097  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4098  if (orderedDims != innerDimsPos) {
4099  // Dimensions don't happen in order.
4100  return false;
4101  }
4102 
4103  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4104  int64_t packedRank = packedTensorType.getRank();
4105  // At this point we know that we are taking numPackedDims outer
4106  // dimensions and pushing them all the way as the inner most dimensions.
4107  // What's left on the outer most dimensions is, in this order:
4108  // - the factor of the packed dimensions, then
4109  // - the untouched dimensions
4110  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4111  // if all the dimensions that bubble outward are ones.
4112  // Therefore check that all the dimensions but the numPackedDims inner most
4113  // ones are ones.
4114  return llvm::all_of(
4115  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4116  [&packedShape](int64_t i) { return packedShape[i] == 1; });
4117 }
4118 
4119 bool PackOp::isLikePad() {
4120  auto packedTensorType =
4121  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4122  return isLikePadUnPad(*this, packedTensorType);
4123 }
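// Illustrative example (SSA names are hypothetical): a pack whose non-tile
// result dimensions are all ones behaves like a pad, e.g.
//
//   %0 = tensor.pack %src padding_value(%pad : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %dest
//       : tensor<5x10xf32> -> tensor<1x1x8x16xf32>
//
// isLikePad() returns true here: inner_dims_pos is the identity and every
// leading (outer) dimension of the packed type is 1.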
4124 
4125 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4126  if (OpFoldResult reshapedSource = reshapeConstantSource(
4127  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4128  getResult().getType()))
4129  return reshapedSource;
4130  return {};
4131 }
4132 
4133 //===----------------------------------------------------------------------===//
4134 // UnPackOp
4135 //===----------------------------------------------------------------------===//
4136 
4137 void UnPackOp::getAsmResultNames(
4138  function_ref<void(Value, StringRef)> setNameFn) {
4139  setNameFn(getResult(), "unpack");
4140 }
4141 
4142 LogicalResult
4143 UnPackOp::reifyResultShapes(OpBuilder &builder,
4144  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4145  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4146 }
4147 
4148 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4149  return getDimAndTileMappingImpl(*this);
4150 }
4151 
4152 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4153  return getMixedTilesImpl(*this);
4154 }
4155 
4156 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4157  return getStaticTilesImpl(*this);
4158 }
4159 
4160 LogicalResult UnPackOp::verify() {
4161  return commonVerifierPackAndUnPackOp(*this);
4162 }
4163 
4164 Speculation::Speculatability UnPackOp::getSpeculatability() {
4165  // See PackOp::getSpeculatability.
4166  if (!areTilesAndTiledDimsAllConstant(*this))
4167  return Speculation::NotSpeculatable;
4168 
4169  return Speculation::Speculatable;
4170 }
4171 
4172 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4173  Value dest, ArrayRef<int64_t> innerDimsPos,
4174  ArrayRef<OpFoldResult> innerTiles,
4175  ArrayRef<int64_t> outerDimsPerm) {
4176  assert(innerDimsPos.size() == innerTiles.size() &&
4177  "number of tile sizes specified must match the specified number of "
4178  "original dimensions to be tiled");
4179  SmallVector<int64_t> staticTileSizes;
4180  SmallVector<Value> dynamicTileSizes;
4181  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4182  build(builder, state, dest.getType(), source, dest,
4183  outerDimsPerm.empty() ? nullptr
4184  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4185  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4186  builder.getDenseI64ArrayAttr(staticTileSizes));
4187 }
4188 
4189 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4190  Value source,
4191  ArrayRef<OpFoldResult> innerTileSizes,
4192  ArrayRef<int64_t> innerDimsPos,
4193  ArrayRef<int64_t> outerDimsPerm) {
4194  AffineExpr sym0, sym1;
4195  bindSymbols(b.getContext(), sym0, sym1);
4196  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4197  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4198  };
4199 
4200  SmallVector<OpFoldResult> mixedSizes;
4201  auto srcType = llvm::cast<RankedTensorType>(source.getType());
4202  for (auto i :
4203  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4204  if (srcType.isDynamicDim(i))
4205  mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4206  else
4207  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4208  }
4209  if (!outerDimsPerm.empty()) {
4210  applyPermutationToVector<OpFoldResult>(
4211  mixedSizes, invertPermutationVector(outerDimsPerm));
4212  }
4213 
4214  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4215  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4216 
4217  auto elemType = srcType.getElementType();
4218  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4219 }
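// Illustrative example: for a source of type tensor<2x8x8x2xf32> with
// inner_dims_pos = [0, 1], inner_tiles = [8, 2], and no outer_dims_perm, the
// leading source sizes [2, 8] are each multiplied by their tile size, and an
// empty destination of type tensor<16x16xf32> is created.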
4220 
4221 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4222  Value transposedSource,
4223  ArrayRef<int64_t> innerPermutation,
4224  ArrayRef<int64_t> outerPermutation) {
4225  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4226  *this, innerPermutation, outerPermutation);
4227  return b.create<UnPackOp>(loc, transposedSource, getDest(),
4228  metadata.innerDimsPos, metadata.innerTiles,
4229  metadata.outerDimsPerm);
4230 }
4231 
4232 /// Returns true if the `srcShape` or `destShape` is different from the one in
4233 /// `op` and populates each with the inferred static shape.
4234 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4235  SmallVectorImpl<int64_t> &destShape) {
4236  bool changeNeeded = false;
4237  srcShape.assign(op.getSourceType().getShape().begin(),
4238  op.getSourceType().getShape().end());
4239  destShape.assign(op.getDestType().getShape().begin(),
4240  op.getDestType().getShape().end());
4241  llvm::SmallSetVector<int64_t, 4> innerDims;
4242  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4243  auto outerDimsPerm = op.getOuterDimsPerm();
4244  int destRank = op.getDestRank();
4245  for (auto i : llvm::seq<int64_t>(0, destRank)) {
4246  if (innerDims.contains(i))
4247  continue;
4248  int64_t srcPos = i;
4249  int64_t destPos = i;
4250  if (!outerDimsPerm.empty())
4251  srcPos = outerDimsPerm[destPos];
4252  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4253  ShapedType::isDynamic(destShape[destPos])) {
4254  continue;
4255  }
4256  int64_t size = srcShape[srcPos];
4257  if (ShapedType::isDynamic(size))
4258  size = destShape[destPos];
4259  srcShape[srcPos] = size;
4260  destShape[destPos] = size;
4261  changeNeeded = true;
4262  }
4263  return changeNeeded;
4264 }
4265 
4266 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4267  PatternRewriter &rewriter) {
4268  /// unpack(pack(x)) -> x
4269  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4270  if (packOp.getDestType() != unPackOp.getSourceType())
4271  return failure();
4272  if (packOp.getPaddingValue() ||
4273  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4274  !haveSameTiles(packOp, unPackOp))
4275  return failure();
4276  rewriter.replaceOp(unPackOp, packOp.getSource());
4277  return success();
4278  }
4279  /// unpack(destinationStyleOp(x)) -> unpack(x)
4280  if (auto dstStyleOp =
4281  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4282  auto destValue = unPackOp.getDest().cast<OpResult>();
4283  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4284  rewriter.modifyOpInPlace(unPackOp,
4285  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4286  return success();
4287  }
4288 
4289  // Insert tensor.cast ops if static shape inference is available.
4290  SmallVector<int64_t> srcShape, destShape;
4291  if (inferStaticShape(unPackOp, srcShape, destShape)) {
4292  Location loc = unPackOp.getLoc();
4293  Value source = unPackOp.getSource();
4294  if (srcShape != unPackOp.getSourceType().getShape()) {
4295  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4296  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4297  unPackOp.getSource());
4298  }
4299  Value dest = unPackOp.getDest();
4300  if (destShape != unPackOp.getDestType().getShape()) {
4301  auto newDestType = unPackOp.getDestType().clone(destShape);
4302  dest =
4303  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4304  }
4305  Value newOp = rewriter.create<tensor::UnPackOp>(
4306  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4307  unPackOp.getOuterDimsPerm());
4308  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4309  unPackOp, unPackOp.getResult().getType(), newOp);
4310  return success();
4311  }
4312 
4313  return failure();
4314 }
4315 
4316 bool UnPackOp::isLikeUnPad() {
4317  RankedTensorType packedTensorType = getSourceType();
4318  return isLikePadUnPad(*this, packedTensorType);
4319 }
4320 
4321 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4322  if (OpFoldResult reshapedSource = reshapeConstantSource(
4323  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4324  getResult().getType()))
4325  return reshapedSource;
4326  return {};
4327 }
4328 
4329 //===----------------------------------------------------------------------===//
4330 // Common Canonicalizers and Folders.
4331 //===----------------------------------------------------------------------===//
4332 
4333 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4334 /// the `tensor.cast` has a source that is more static than the consuming op.
4335 ///
4336 /// Example:
4337 /// ```mlir
4338 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4339 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4340 /// ```
4341 ///
4342 /// folds into:
4343 ///
4344 /// ```mlir
4345 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4346 /// ```
4347 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4348 /// can add the pattern to their canonicalizers.
4349 struct FoldTensorCastProducerOp
4350  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4351  using OpInterfaceRewritePattern<
4352  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4353 
4354  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4355  PatternRewriter &rewriter) const override {
4356  // InsertSliceOp has its own logic about folding tensor.cast ops.
4357  if (isa<InsertSliceOp>(op.getOperation()))
4358  return failure();
4359 
4360  // Exclude DPS ops that are also LoopLike from this interface as they
4361  // might need special handling of attached regions.
4362  if (isa<LoopLikeOpInterface>(op.getOperation()))
4363  return failure();
4364 
4365  // If no operand comes from a tensor::CastOp that can be folded, then fail.
4366  bool hasTensorCastOperand =
4367  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4368  if (llvm::isa<BlockArgument>(opOperand.get()))
4369  return false;
4370  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4371  return castOp && canFoldIntoConsumerOp(castOp);
4372  });
4373  if (!hasTensorCastOperand)
4374  return failure();
4375 
4376  SmallVector<Type, 4> newResultTypes;
4377  newResultTypes.reserve(op->getNumResults());
4378  SmallVector<Value, 4> newOperands;
4379  newOperands.reserve(op->getNumOperands());
4380  for (OpOperand &opOperand : op->getOpOperands()) {
4381  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4382  bool fold = canFoldIntoConsumerOp(tensorCastOp);
4383  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4384  if (op.isDpsInit(&opOperand) &&
4385  !llvm::isa<MemRefType>(newOperands.back().getType()))
4386  newResultTypes.push_back(newOperands.back().getType());
4387  }
4388 
4389  // Clone op.
4390  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
4391  SmallVector<Value, 4> replacements;
4392  replacements.reserve(newOp->getNumResults());
4393  for (auto [oldResult, newResult] :
4394  llvm::zip(op->getResults(), newOp->getResults())) {
4395  if (newResult.getType() != oldResult.getType()) {
4396  replacements.push_back(rewriter.create<tensor::CastOp>(
4397  op->getLoc(), oldResult.getType(), newResult));
4398  } else {
4399  replacements.push_back(newResult);
4400  }
4401  }
4402  rewriter.replaceOp(op, replacements);
4403 
4404  return success();
4405  }
4406 };
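
Note that when the folded operand is a DPS init, the cloned op's result type is refined as well, and the replacement loop above inserts a `tensor.cast` back to the original type so existing uses stay valid. In the same pseudo-IR style as the doc comment above (`consumer` stands for any destination-style op with a tied result):

```mlir
%1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
%2 = consumer ... outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>

// becomes:
%3 = consumer ... outs(%0 : tensor<8x16xf32>) -> tensor<8x16xf32>
%4 = tensor.cast %3 : tensor<8x16xf32> to tensor<?x?xf32>
```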
4407 
4408 //===----------------------------------------------------------------------===//
4409 // TensorDialect
4410 //===----------------------------------------------------------------------===//
4411 
4412 void TensorDialect::getCanonicalizationPatterns(
4413  RewritePatternSet &results) const {
4414   results.add<FoldTensorCastProducerOp>(getContext());
4415 }
4416 
4417 //===----------------------------------------------------------------------===//
4418 // TableGen'd op method definitions
4419 //===----------------------------------------------------------------------===//
4420 
4421 #define GET_OP_CLASSES
4422 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"