MLIR 19.0.0git
TensorOps.cpp
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include <algorithm>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::tensor;
37 
38 /// Materialize a single constant operation from a given attribute value with
39 /// the desired resultant type.
40 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
41  Attribute value, Type type,
42  Location loc) {
43  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
44  return op;
45  if (complex::ConstantOp::isBuildableWith(value, type))
46  return builder.create<complex::ConstantOp>(loc, type,
47  llvm::cast<ArrayAttr>(value));
48  return nullptr;
49 }
50 
51 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
52  int64_t dim) {
53  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
55  if (tensorType.isDynamicDim(dim))
56  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
57 
58  return builder.getIndexAttr(tensorType.getDimSize(dim));
59 }
60 
61 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
62  Location loc, Value value) {
63  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
64  SmallVector<OpFoldResult> result;
65  for (int64_t i = 0; i < tensorType.getRank(); ++i)
66  result.push_back(getMixedSize(builder, loc, value, i));
67  return result;
68 }
69 
70 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
71  OpResult opResult) {
72  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
73  assert(tensorType && "expected tensor type");
74 
75  // If the op has a destination, it implements DestinationStyleOpInterface and
76  // we can query the destination operand from that interface.
77  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
78  if (destOp)
79  return destOp.getTiedOpOperand(opResult)->get();
80 
81  // Otherwise, create a new destination tensor with the same shape.
82  OpBuilder::InsertionGuard g(b);
83  b.setInsertionPoint(opResult.getDefiningOp());
84 
85  // Compute sizes.
86  SmallVector<OpFoldResult> mixedSizes;
87  if (!tensorType.hasStaticShape()) {
88  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
89  ReifiedRankedShapedTypeDims reifiedShapes;
90  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
91  return failure();
92  mixedSizes = reifiedShapes[opResult.getResultNumber()];
93  } else {
94  // Static shape: Take static sizes directly.
95  for (int64_t sz : tensorType.getShape())
96  mixedSizes.push_back(b.getIndexAttr(sz));
97  }
98 
99  // Create empty tensor.
100  Value emptyTensor =
101  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
102  return emptyTensor;
103 }
104 
105 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
106  Operation *op,
107  SmallVector<Value> &result) {
108  for (OpResult opResult : op->getResults()) {
109  if (llvm::isa<TensorType>(opResult.getType())) {
110  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
111  if (failed(destination))
112  return failure();
113  result.push_back(*destination);
114  }
115  }
116  return success();
117 }
118 
119 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
120  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
121  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
122  return rtp1.getShape() == rtp2.getShape() &&
123  rtp1.getElementType() == rtp2.getElementType();
124  return false;
125  }
126  return tp1 == tp2; // default implementation
127 }
128 
129 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
130 /// rank-extending tensor.insert_slice op.
131 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
132  ArrayRef<OpFoldResult> mixedSizes) {
133  llvm::SmallBitVector droppedDims(mixedSizes.size());
134  int64_t shapePos = 0;
135 
136  for (const auto &size : enumerate(mixedSizes)) {
137  // Rank-reduced dims must have a static unit dimension.
138  bool isStaticUnitSize =
139  size.value().is<Attribute>() &&
140  llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
141 
142  if (shapePos == static_cast<int64_t>(reducedShape.size())) {
143  // There are no more dims in the reduced shape. All remaining sizes must
144  // be rank-reduced dims.
145  assert(isStaticUnitSize && "expected unit dim");
146  droppedDims.set(size.index());
147  continue;
148  }
149 
150  // Dim is preserved if the size is not a static 1.
151  if (!isStaticUnitSize) {
152  ++shapePos;
153  continue;
154  }
155 
156  // Dim is preserved if the reduced shape dim is also 1.
157  if (reducedShape[shapePos] == 1) {
158  ++shapePos;
159  continue;
160  }
161 
162  // Otherwise: Dim is dropped.
163  droppedDims.set(size.index());
164  }
165 
166  assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
167  "dimension mismatch");
168  return droppedDims;
169 }
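// Illustrative example (added annotation, not part of the upstream file): for
// a rank-reducing slice such as
//
//   %0 = tensor.extract_slice %t[0, 0, 0] [1, %sz, 1] [1, 1, 1]
//       : tensor<1x?x1xf32> to tensor<?xf32>
//
// the mixed sizes are [1, %sz, 1] and the reduced shape is [?], so the helper
// above marks dimensions 0 and 2 as dropped.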
170 
171 /// Given a ranked tensor type and a range of values that defines its dynamic
172 /// dimension sizes, turn all dynamic sizes that have a constant value into
173 /// static dimension sizes.
174 static RankedTensorType
175 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
176  SmallVector<Value> &foldedDynamicSizes) {
177  SmallVector<int64_t> staticShape(type.getShape().begin(),
178  type.getShape().end());
179  assert(type.getNumDynamicDims() ==
180  static_cast<int64_t>(dynamicSizes.size()) &&
181  "incorrect number of dynamic sizes");
182 
183  // Compute new static and dynamic sizes.
184  unsigned ctr = 0;
185  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
186  if (type.isDynamicDim(i)) {
187  Value dynamicSize = dynamicSizes[ctr++];
188  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
189  if (cst.has_value()) {
190  // Dynamic size must be non-negative.
191  if (cst.value() < 0) {
192  foldedDynamicSizes.push_back(dynamicSize);
193  continue;
194  }
195  staticShape[i] = *cst;
196  } else {
197  foldedDynamicSizes.push_back(dynamicSize);
198  }
199  }
200  }
201 
202  return RankedTensorType::get(staticShape, type.getElementType(),
203  type.getEncoding());
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // BitcastOp
208 //===----------------------------------------------------------------------===//
209 
210 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
211  if (inputs.size() != 1 || outputs.size() != 1)
212  return false;
213  Type a = inputs.front(), b = outputs.front();
214  auto aT = dyn_cast<TensorType>(a);
215  auto bT = dyn_cast<TensorType>(b);
216  if (!aT || !bT)
217  return false;
218 
219  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
220  return false;
221 
222  return succeeded(verifyCompatibleShape(aT, bT));
223 }
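// Example (added annotation, not part of the upstream file): a bitcast is
// accepted only between element types of the same bit width, e.g.
//
//   %ok = tensor.bitcast %a : tensor<4xui32> to tensor<4xi32>
//
// while a width-changing cast such as tensor<4xui32> to tensor<4xi64> fails
// the element bit width check above.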
224 
225 namespace {
226 
227 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
228 /// operation.
229 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
230  using OpRewritePattern<BitcastOp>::OpRewritePattern;
231 
232  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
233  PatternRewriter &rewriter) const final {
234  auto tensorBitcastOperand =
235  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
236  if (!tensorBitcastOperand)
237  return failure();
238 
239  auto resultType = cast<TensorType>(tensorBitcast.getType());
240  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
241  tensorBitcastOperand.getOperand());
242  return success();
243  }
244 };
245 
246 } // namespace
247 
248 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
249  MLIRContext *context) {
250  results.add<ChainedTensorBitcast>(context);
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // CastOp
255 //===----------------------------------------------------------------------===//
256 
257 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
258  setNameFn(getResult(), "cast");
259 }
260 
261 /// Returns true if `target` is a ranked tensor type that preserves static
262 /// information available in the `source` ranked tensor type.
263 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
264  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
265  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
266 
267  // Requires RankedTensorType.
268  if (!sourceType || !targetType)
269  return false;
270 
271  // Requires same elemental type.
272  if (sourceType.getElementType() != targetType.getElementType())
273  return false;
274 
275  // Requires same rank.
276  if (sourceType.getRank() != targetType.getRank())
277  return false;
278 
279  // Requires same encoding.
280  if (sourceType.getEncoding() != targetType.getEncoding())
281  return false;
282 
283  // If cast is towards more static sizes along any dimension, don't fold.
284  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
285  if (!ShapedType::isDynamic(std::get<0>(t)) &&
286  ShapedType::isDynamic(std::get<1>(t)))
287  return false;
288  }
289 
290  return true;
291 }
292 
293 /// Determines whether tensor::CastOp casts to a more dynamic version of the
294 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
295 /// implement canonicalization patterns for ops in different dialects that may
296 /// consume the results of tensor.cast operations. Such foldable tensor.cast
297 /// operations are typically inserted as `slice` ops and are canonicalized
298 /// to preserve the type compatibility of their uses.
299 ///
300 /// Returns true when all conditions are met:
301 /// 1. source and result are ranked tensors with same element type and rank.
302 /// 2. the source tensor type has more static information than the result type.
303 ///
304 /// Example:
305 /// ```mlir
306 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
307 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
308 /// ```
309 ///
310 /// folds into:
311 ///
312 /// ```mlir
313 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
314 /// ```
315 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
316  if (!castOp)
317  return false;
318 
319  // Can fold if the source of cast has at least as much static information as
320  // its results.
321  return preservesStaticInformation(castOp.getType(),
322  castOp.getSource().getType());
323 }
324 
325 /// Determines whether the tensor::CastOp casts to a more static version of the
326 /// source tensor. This is useful to fold into a producing op and implement
327 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
328 /// being from different dialects. Returns true when all conditions are met:
329 /// 1. source and result are ranked tensors with same element type and rank.
330 /// 2. the result type has more static information than the source.
331 ///
332 /// Example:
333 /// ```mlir
334 /// %1 = producer ... : tensor<?x?xf32>
335 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
336 /// ```
337 ///
338 /// can be canonicalized to:
339 ///
340 /// ```mlir
341 /// %2 = producer ... : tensor<8x16xf32>
342 /// ```
343 /// Not all ops might be canonicalizable this way, but for those that can be,
344 /// this method provides a check that it is worth doing the canonicalization.
345 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
346  if (!castOp)
347  return false;
348  return preservesStaticInformation(castOp.getSource().getType(),
349  castOp.getType());
350 }
351 
352 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
353 /// that can be folded.
354 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
355  bool folded = false;
356  for (OpOperand &operand : op->getOpOperands()) {
357  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
358  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
359  operand.set(castOp.getOperand());
360  folded = true;
361  }
362  }
363  return success(folded);
364 }
365 
366 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
367  if (inputs.size() != 1 || outputs.size() != 1)
368  return false;
369  Type a = inputs.front(), b = outputs.front();
370  auto aT = llvm::dyn_cast<TensorType>(a);
371  auto bT = llvm::dyn_cast<TensorType>(b);
372  if (!aT || !bT)
373  return false;
374 
375  if (aT.getElementType() != bT.getElementType())
376  return false;
377 
378  return succeeded(verifyCompatibleShape(aT, bT));
379 }
380 
381 /// Compute a TensorType that has the joined shape knowledge of the two
382 /// given TensorTypes. The element types need to match.
383 static TensorType joinShapes(TensorType one, TensorType two) {
384  assert(one.getElementType() == two.getElementType());
385 
386  if (!one.hasRank())
387  return two;
388  if (!two.hasRank())
389  return one;
390 
391  int64_t rank = one.getRank();
392  if (rank != two.getRank())
393  return {};
394 
395  SmallVector<int64_t, 4> join;
396  join.reserve(rank);
397  for (int64_t i = 0; i < rank; ++i) {
398  if (one.isDynamicDim(i)) {
399  join.push_back(two.getDimSize(i));
400  continue;
401  }
402  if (two.isDynamicDim(i)) {
403  join.push_back(one.getDimSize(i));
404  continue;
405  }
406  if (one.getDimSize(i) != two.getDimSize(i))
407  return {};
408  join.push_back(one.getDimSize(i));
409  }
410  return RankedTensorType::get(join, one.getElementType());
411 }
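// Worked example (added annotation, not part of the upstream file): joining
// tensor<?x8xf32> with tensor<4x?xf32> keeps the most static extent of each
// dimension and yields tensor<4x8xf32>; joining tensor<4x8xf32> with
// tensor<5x8xf32> returns a null type because the static sizes conflict.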
412 
413 namespace {
414 
415 /// Replaces chains of two tensor.cast operations by a single tensor.cast
416 /// operation if doing so does not remove runtime constraints.
417 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
418  using OpRewritePattern<CastOp>::OpRewritePattern;
419 
420  LogicalResult matchAndRewrite(CastOp tensorCast,
421  PatternRewriter &rewriter) const final {
422  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
423 
424  if (!tensorCastOperand)
425  return failure();
426 
427  auto sourceType =
428  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
429  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
430  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
431 
432  // We can remove the intermediate cast if joining all three produces the
433  // same result as just joining the source and result shapes.
434  auto firstJoin =
435  joinShapes(joinShapes(sourceType, intermediateType), resultType);
436 
437  // The join might not exist if the cast sequence would fail at runtime.
438  if (!firstJoin)
439  return failure();
440 
441  // The newJoin always exists if the above join exists, it might just contain
442  // less information. If so, we cannot drop the intermediate cast, as doing
443  // so would remove runtime checks.
444  auto newJoin = joinShapes(sourceType, resultType);
445  if (firstJoin != newJoin)
446  return failure();
447 
448  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
449  tensorCastOperand.getOperand());
450  return success();
451  }
452 };
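// Example of the chain folding above (added annotation, not part of the
// upstream file):
//
//   %a = tensor.cast %t : tensor<4x8xf32> to tensor<4x?xf32>
//   %b = tensor.cast %a : tensor<4x?xf32> to tensor<?x?xf32>
//
// folds to a single cast because join(join(4x8, 4x?), ?x?) == join(4x8, ?x?):
//
//   %b = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>
//
// whereas the chain tensor<?x8xf32> -> tensor<4x8xf32> -> tensor<?x8xf32> is
// kept, since dropping the middle cast would lose the runtime check that the
// leading extent is 4.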
453 
454 /// Fold tensor.cast into tensor.extract_slice producer.
455 /// Example:
456 /// ```
457 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
458 /// tensor<128x512xf32> to tensor<?x512xf32>
459 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
460 /// ```
461 /// ->
462 /// ```
463 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
464 /// tensor<128x512xf32> to tensor<16x512xf32>
465 /// ```
466 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
467  using OpRewritePattern<CastOp>::OpRewritePattern;
468 
469  LogicalResult matchAndRewrite(CastOp tensorCast,
470  PatternRewriter &rewriter) const final {
471  auto extractOperand =
472  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
473 
474  // Cannot fold cast to unranked tensor.
475  auto rankedResultType =
476  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
477  if (!rankedResultType)
478  return failure();
479 
480  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
481  rankedResultType.getShape() ==
482  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
483  .getShape())
484  return failure();
485 
486  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
487  auto dimMask = computeRankReductionMask(
488  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
489  size_t dimIndex = 0;
490  for (size_t i = 0, e = sizes.size(); i < e; i++) {
491  if (dimMask && dimMask->count(i))
492  continue;
493  int64_t dim = rankedResultType.getShape()[dimIndex++];
494  if (ShapedType::isDynamic(dim))
495  continue;
496  sizes[i] = rewriter.getIndexAttr(dim);
497  }
498 
499  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
500  tensorCast, rankedResultType, extractOperand.getSource(),
501  extractOperand.getMixedOffsets(), sizes,
502  extractOperand.getMixedStrides());
503  return success();
504  }
505 };
506 
507 } // namespace
508 
509 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
510  MLIRContext *context) {
511  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
512 }
513 
514 //===----------------------------------------------------------------------===//
515 // ConcatOp
516 //===----------------------------------------------------------------------===//
517 
518 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
519  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
520  auto tensorTypes =
521  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
522  return llvm::cast<RankedTensorType>(type);
523  }));
524  int64_t concatRank = tensorTypes[0].getRank();
525 
526  // The concatenation dim must be in the range [0, rank).
527  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
528 
529  SmallVector<int64_t> sizes(concatRank);
530  for (int64_t i = 0, e = concatRank; i < e; ++i) {
531  if (i == dim)
532  continue;
533  SaturatedInteger size;
534  for (auto tensorType : tensorTypes)
535  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
536  sizes[i] = size.asInteger();
537  }
538  auto concatSize = SaturatedInteger::wrap(0);
539  for (auto tensorType : tensorTypes)
540  concatSize =
541  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
542  sizes[dim] = concatSize.asInteger();
543  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
544 }
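// Worked example (added annotation, not part of the upstream file):
// concatenating tensor<3x?xf32> and tensor<5x4xf32> along dim 0 infers
// tensor<8x4xf32>: the concatenated extent is the saturating sum 3 + 5, and
// the non-concatenated dimension desaturates ? against 4 to the static size 4:
//
//   %c = tensor.concat dim(0) %a, %b
//       : (tensor<3x?xf32>, tensor<5x4xf32>) -> tensor<8x4xf32>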
545 
546 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
547  ValueRange inputs) {
548  FailureOr<RankedTensorType> resultType =
549  inferResultType(dim, inputs.getTypes());
550  assert(succeeded(resultType) && "failed to infer concatenation result type");
551  build(builder, result, *resultType, dim, inputs);
552 }
553 
554 LogicalResult ConcatOp::verify() {
555  if (getInputs().size() < 1)
556  return emitOpError("requires at least one input");
557 
558  SmallVector<RankedTensorType> inputTypes;
559  for (auto input : getInputs())
560  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
561 
562  RankedTensorType resultType = getResultType();
563  int64_t resultRank = getRank();
564  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
565  return type.getRank() != resultRank;
566  }))
567  return emitOpError("rank of concatenated inputs must match result rank");
568 
569  Type resultElementType = resultType.getElementType();
570  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
571  return type.getElementType() != resultElementType;
572  }))
573  return emitOpError("inputs and result element type must match");
574 
575  int64_t dim = getDim();
576  if (dim >= resultRank)
577  return emitOpError("concatenation dim must be less than the tensor rank");
578 
579  SmallVector<int64_t> sizes(resultRank);
580  for (int64_t i = 0, e = resultRank; i < e; ++i) {
581  if (i == dim)
582  continue;
583  SaturatedInteger size;
584  for (auto tensorType : inputTypes) {
585  FailureOr<SaturatedInteger> maybeSize =
586  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
587  if (failed(maybeSize))
588  return emitOpError("static concatenation size mismatch along ")
589  << "non-concatenated dimension " << i;
590  size = *maybeSize;
591  }
592  sizes[i] = size.asInteger();
593  }
594  auto concatSize = SaturatedInteger::wrap(0);
595  for (auto tensorType : inputTypes)
596  concatSize =
597  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
598  sizes[dim] = concatSize.asInteger();
599  auto inferredResultType =
600  RankedTensorType::get(sizes, inputTypes[0].getElementType());
601 
602  for (auto [inferredSize, actualSize] :
603  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
604  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
605  ShapedType::isDynamic(actualSize);
606  if (!hasDynamic && inferredSize != actualSize)
607  return emitOpError("result type ")
608  << resultType << " does not match inferred shape "
609  << inferredResultType << " static sizes";
610  }
611 
612  return success();
613 }
614 
615 LogicalResult
616 ConcatOp::reifyResultShapes(OpBuilder &builder,
617  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
618  ValueRange inputs = getInputs();
619  int64_t dim = getDim();
620  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
621 
622  Value init = inputs[0];
623  int64_t rank = getType().getRank();
624 
625  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
626 
627  // Pre-populate the result sizes with as much static information as possible
628  // from the given result type, as well as the inferred result type, otherwise
629  // use the dim sizes from the first input.
630  for (int64_t i = 0; i < rank; ++i) {
631  if (i == dim)
632  continue;
633  if (!getType().isDynamicDim(i)) {
634  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
635  } else if (!inferredResultType.isDynamicDim(i)) {
636  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
637  builder, getLoc(),
638  builder.getIndexAttr(inferredResultType.getDimSize(i)));
639  } else {
640  reifiedReturnShapes[0][i] =
641  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
642  }
643  }
644 
645  if (getType().isDynamicDim(dim)) {
646  // Take the sum of the input sizes along the concatenated dim.
647  AffineExpr sum = builder.getAffineDimExpr(0);
648  SmallVector<OpFoldResult> sizes = {
649  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
650  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
651  sum = sum + builder.getAffineDimExpr(idx + 1);
652  sizes.push_back(
653  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
654  }
655  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
656  builder, getLoc(),
657  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
658  } else {
659  // If the result shape is static along the concatenated dim, use the static
660  // shape.
661  reifiedReturnShapes[0][dim] =
662  builder.getIndexAttr(getType().getDimSize(dim));
663  }
664  return success();
665 }
666 
667 void ConcatOp::getAsmResultNames(
668  function_ref<void(Value, StringRef)> setNameFn) {
669  setNameFn(getResult(), "concat");
670 }
671 
672 OpFoldResult ConcatOp::fold(FoldAdaptor) {
673  ValueRange inputs = getInputs();
674  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
675  return inputs[0];
676  return {};
677 }
678 
679 namespace {
680 /// Fold a concat op with a single input to a cast.
681 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
682  using OpRewritePattern<ConcatOp>::OpRewritePattern;
683 
684  LogicalResult matchAndRewrite(ConcatOp concatOp,
685  PatternRewriter &rewriter) const override {
686  if (concatOp.getInputs().size() != 1)
687  return failure();
688  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
689  concatOp.getInputs()[0]);
690  return success();
691  }
692 };
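// Example (added annotation, not part of the upstream file): a concat with a
// single input carries no concatenation semantics and becomes a cast to the
// declared result type:
//
//   %0 = tensor.concat dim(0) %a : (tensor<?x4xf32>) -> tensor<8x4xf32>
//   // ... is rewritten to ...
//   %0 = tensor.cast %a : tensor<?x4xf32> to tensor<8x4xf32>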
693 } // namespace
694 
695 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
696  MLIRContext *context) {
697  results.add<SingleInputConcatOp>(context);
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // DimOp
702 //===----------------------------------------------------------------------===//
703 
704 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
705  setNameFn(getResult(), "dim");
706 }
707 
708 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
709  int64_t index) {
710  auto loc = result.location;
711  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
712  build(builder, result, source, indexValue);
713 }
714 
715 std::optional<int64_t> DimOp::getConstantIndex() {
716  return getConstantIntValue(getIndex());
717 }
718 
719 Speculation::Speculatability DimOp::getSpeculatability() {
720  auto constantIndex = getConstantIndex();
721  if (!constantIndex)
722  return Speculation::NotSpeculatable;
723 
724  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
725  if (!rankedSourceType)
726  return Speculation::NotSpeculatable;
727 
728  if (rankedSourceType.getRank() <= constantIndex)
729  return Speculation::NotSpeculatable;
730 
731  return Speculation::Speculatable;
732 }
733 
734 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
735  // All forms of folding require a known index.
736  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
737  if (!index)
738  return {};
739 
740  // Folding for unranked types (UnrankedTensorType) is not supported.
741  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
742  if (!tensorType)
743  return {};
744 
745  // Out of bound indices produce undefined behavior but are still valid IR.
746  // Don't choke on them.
747  int64_t indexVal = index.getInt();
748  if (indexVal < 0 || indexVal >= tensorType.getRank())
749  return {};
750 
751  // Fold if the shape extent along the given index is known.
752  if (!tensorType.isDynamicDim(index.getInt())) {
753  Builder builder(getContext());
754  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
755  }
756 
757  Operation *definingOp = getSource().getDefiningOp();
758 
759  // Fold dim to the operand of tensor.generate.
760  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
761  auto resultType =
762  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
763  // The case where the type encodes the size of the dimension is handled
764  // above.
765  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
766 
767  // Find the operand of the fromElements that corresponds to this index.
768  auto dynExtents = fromElements.getDynamicExtents().begin();
769  for (auto dim : resultType.getShape().take_front(index.getInt()))
770  if (ShapedType::isDynamic(dim))
771  dynExtents++;
772 
773  return Value{*dynExtents};
774  }
775 
776  // The size at the given index is now known to be a dynamic size.
777  unsigned unsignedIndex = index.getValue().getZExtValue();
778 
779  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
780  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
781  // `resolve-shaped-type-result-dims` pass.
782  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
783  sliceOp.isDynamicSize(unsignedIndex)) {
784  return {sliceOp.getDynamicSize(unsignedIndex)};
785  }
786  }
787 
788  // dim(cast) -> dim
789  if (succeeded(foldTensorCast(*this)))
790  return getResult();
791 
792  return {};
793 }
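// Folding examples (added annotation, not part of the upstream file):
//
//   %c0 = arith.constant 0 : index
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %t, %c1 : tensor<?x8xf32>   // folds to the constant 8
//
//   %g  = tensor.generate %n { ... } : tensor<?xindex>
//   %d0 = tensor.dim %g, %c0 : tensor<?xindex>   // folds to %n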
794 
795 namespace {
796 /// Fold dim of a cast into the dim of the source of the tensor cast.
797 struct DimOfCastOp : public OpRewritePattern<DimOp> {
798  using OpRewritePattern<DimOp>::OpRewritePattern;
799 
800  LogicalResult matchAndRewrite(DimOp dimOp,
801  PatternRewriter &rewriter) const override {
802  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
803  if (!castOp)
804  return failure();
805  Value newSource = castOp.getOperand();
806  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
807  return success();
808  }
809 };
810 
811 /// Fold dim of a destination passing style op into the dim of the corresponding
812 /// init.
813 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
814  using OpRewritePattern<DimOp>::OpRewritePattern;
815 
816  LogicalResult matchAndRewrite(DimOp dimOp,
817  PatternRewriter &rewriter) const override {
818  auto source = dimOp.getSource();
819  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
820  if (!destOp)
821  return failure();
822 
823  auto resultIndex = source.cast<OpResult>().getResultNumber();
824  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
825 
826  rewriter.modifyOpInPlace(
827  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
828  return success();
829  }
830 };
831 
832 /// Fold dim of a tensor reshape operation to a extract into the reshape's shape
833 /// operand.
834 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
835  using OpRewritePattern<DimOp>::OpRewritePattern;
836 
837  LogicalResult matchAndRewrite(DimOp dim,
838  PatternRewriter &rewriter) const override {
839  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
840 
841  if (!reshape)
842  return failure();
843 
844  // Since tensors are immutable, we don't need to worry about where to place
845  // the extract call.
846  rewriter.setInsertionPointAfter(dim);
847  Location loc = dim.getLoc();
848  Value extract =
849  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
850  if (extract.getType() != dim.getType())
851  extract =
852  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
853  rewriter.replaceOp(dim, extract);
854  return success();
855  }
856 };
857 } // namespace
858 
859 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
860  MLIRContext *context) {
861  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // EmptyOp
866 //===----------------------------------------------------------------------===//
867 
868 void EmptyOp::build(OpBuilder &builder, OperationState &result,
869  ArrayRef<int64_t> staticShape, Type elementType,
870  Attribute encoding) {
871  assert(all_of(staticShape,
872  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
873  "expected only static sizes");
874  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
875 }
876 
877 void EmptyOp::build(OpBuilder &builder, OperationState &result,
878  ArrayRef<int64_t> staticShape, Type elementType,
879  ValueRange dynamicSizes, Attribute encoding) {
880  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
881  build(builder, result, tensorType, dynamicSizes);
882 }
883 
884 void EmptyOp::build(OpBuilder &builder, OperationState &result,
885  ArrayRef<OpFoldResult> sizes, Type elementType,
886  Attribute encoding) {
887  SmallVector<int64_t> staticShape;
888  SmallVector<Value> dynamicSizes;
889  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
890  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
891 }
892 
893 LogicalResult EmptyOp::verify() {
894  if (getType().getNumDynamicDims() !=
895  static_cast<int64_t>(getDynamicSizes().size()))
896  return emitOpError("incorrect number of dynamic sizes, has ")
897  << getDynamicSizes().size() << ", expected "
898  << getType().getNumDynamicDims();
899  return success();
900 }
901 
902 LogicalResult
903 EmptyOp::reifyResultShapes(OpBuilder &builder,
904  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
905  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
906  unsigned ctr = 0;
907  for (int64_t i = 0; i < getType().getRank(); ++i) {
908  if (getType().isDynamicDim(i)) {
909  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
910  } else {
911  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
912  }
913  }
914  return success();
915 }
916 
917 Value EmptyOp::getDynamicSize(unsigned idx) {
918  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
919  unsigned ctr = 0;
920  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
921  if (getType().isDynamicDim(i))
922  ++ctr;
923  return getDynamicSizes()[ctr];
924 }
925 
926 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
927  SmallVector<OpFoldResult> result;
928  unsigned ctr = 0;
929  OpBuilder b(getContext());
930  for (int64_t i = 0; i < getType().getRank(); ++i) {
931  if (getType().isDynamicDim(i)) {
932  result.push_back(getDynamicSizes()[ctr++]);
933  } else {
934  result.push_back(b.getIndexAttr(getType().getShape()[i]));
935  }
936  }
937  return result;
938 }
939 
940 namespace {
941 /// Change the type of the result of a `tensor.empty` by making the result
942 /// type statically sized along dimensions that in the original operation were
943 /// defined as dynamic, but the size was defined using a `constant` op. For
944 /// example
945 ///
946 /// %c5 = arith.constant 5: index
947 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
948 ///
949 /// to
950 ///
951 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
952 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
953  using OpRewritePattern<EmptyOp>::OpRewritePattern;
954 
955  LogicalResult matchAndRewrite(EmptyOp op,
956  PatternRewriter &rewriter) const override {
957  SmallVector<Value> foldedDynamicSizes;
958  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
959  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
960 
961  // Stop here if no dynamic size was promoted to static.
962  if (foldedTensorType == op.getType())
963  return failure();
964 
965  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
966  foldedDynamicSizes);
967  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
968  return success();
969  }
970 };
971 
972 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
973  using OpRewritePattern<DimOp>::OpRewritePattern;
974 
975  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
976  PatternRewriter &rewriter) const override {
977  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
978  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
979  if (!emptyTensorOp || !maybeConstantIndex)
980  return failure();
981  if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
982  return failure();
983  rewriter.replaceOp(dimOp,
984  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
985  return success();
986  }
987 };
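// Example (added annotation, not part of the upstream file): querying a
// dynamic dimension of a tensor.empty simply forwards the matching size
// operand:
//
//   %c1 = arith.constant 1 : index
//   %e  = tensor.empty(%d0, %d1) : tensor<?x?xf32>
//   %s  = tensor.dim %e, %c1 : tensor<?x?xf32>   // replaced by %d1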
988 
989 /// Canonicalize
990 ///
991 /// ```mlir
992 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
993 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
994 /// ```
995 ///
996 /// into
997 ///
998 /// ```mlir
999 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1000 /// ```
1001 ///
1002 /// This assumes the input program is correct in terms of its shape. So it is
1003 /// safe to assume that `%d0` is in fact 4.
1004 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1005  using OpRewritePattern<CastOp>::OpRewritePattern;
1006 
1007  LogicalResult matchAndRewrite(CastOp castOp,
1008  PatternRewriter &rewriter) const override {
1009  if (!canFoldIntoProducerOp(castOp))
1010  return failure();
1011  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1012  if (!producer)
1013  return failure();
1014 
1015  auto resultType =
1016  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1017  ArrayRef<int64_t> resultShape = resultType.getShape();
1018  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1019  SmallVector<OpFoldResult> newMixedSizes;
1020  newMixedSizes.reserve(currMixedSizes.size());
1021  assert(resultShape.size() == currMixedSizes.size() &&
1022  "mismatch in result shape and sizes of empty op");
1023  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1024  int64_t newDim = std::get<0>(it);
1025  OpFoldResult currDim = std::get<1>(it);
1026  // Case 1: The empty tensor dim is static. Check that the tensor cast
1027  // result dim matches.
1028  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1029  if (ShapedType::isDynamic(newDim) ||
1030  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1031  // Something is off, the cast result shape cannot be more dynamic
1032  // than the empty tensor result shape (enforced by
1033  // `canFoldIntoProducer`). Abort for now.
1034  return rewriter.notifyMatchFailure(
1035  producer, "mismatch in static value of shape of empty tensor "
1036  "result and cast result");
1037  }
1038  newMixedSizes.push_back(attr);
1039  continue;
1040  }
1041 
1042  // Case 2 : The tensor cast shape is static, but empty tensor result
1043  // shape is dynamic.
1044  if (!ShapedType::isDynamic(newDim)) {
1045  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1046  continue;
1047  }
1048 
1049  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1050  // shape is dynamic. Use the dynamic value from the empty tensor op.
1051  newMixedSizes.push_back(currDim);
1052  }
1053 
1054  // TODO: Do not drop tensor encoding.
1055  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1056  resultType.getElementType());
1057  return success();
1058  }
1059 };
1060 
1061 } // namespace
1062 
1063 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1064  MLIRContext *context) {
1065  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1066  ReplaceEmptyTensorStaticShapeDims>(context);
1067 }
1068 
1069 /// Try to remove a tensor operation if it would only reshape a constant.
1070 /// Removes the op and replaces the constant with a new constant of the result
1071 /// shape. When an optional cst attribute is passed, it is reshaped only if the
1072 /// splat value matches the value in the attribute.
1073 static OpFoldResult
1074 reshapeConstantSource(DenseElementsAttr source, TensorType result,
1075  std::optional<Attribute> cst = std::nullopt) {
1076  if (source && source.isSplat() && result.hasStaticShape() &&
1077  (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
1078  return source.resizeSplat(result);
1079 
1080  return {};
1081 }
1082 
1083 //===----------------------------------------------------------------------===//
1084 // ExtractOp
1085 //===----------------------------------------------------------------------===//
1086 
1087 namespace {
1088 
1089 /// Canonicalizes the pattern of the form
1090 ///
1091 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1092 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1093 ///
1094 /// to
1095 ///
1096 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1097 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1098  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1099 
1100  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1101  PatternRewriter &rewriter) const final {
1102  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1103  if (!tensorCast)
1104  return failure();
1105  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1106  return failure();
1107  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1108  extract, tensorCast.getSource(), extract.getIndices());
1109  return success();
1110  }
1111 };
1112 
1113 } // namespace
1114 
1115 void ExtractOp::getAsmResultNames(
1116  function_ref<void(Value, StringRef)> setNameFn) {
1117  setNameFn(getResult(), "extracted");
1118 }
1119 
1120 LogicalResult ExtractOp::verify() {
1121  // Verify the # indices match if we have a ranked type.
1122  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1123  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1124  return emitOpError("incorrect number of indices for extract_element");
1125  return success();
1126 }
1127 
1128 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1129  // If this is a splat elements attribute, simply return the value. All of
1130  // the elements of a splat attribute are the same.
1131  if (Attribute tensor = adaptor.getTensor())
1132  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1133  return splatTensor.getSplatValue<Attribute>();
1134 
1135  // Collect the constant indices into the tensor.
1136  SmallVector<uint64_t, 8> indices;
1137  for (Attribute indice : adaptor.getIndices()) {
1138  if (!indice || !llvm::isa<IntegerAttr>(indice))
1139  return {};
1140  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1141  }
1142 
1143  // Fold extract(from_elements(...)).
1144  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1145  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1146  auto rank = tensorType.getRank();
1147  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1148  "rank mismatch");
1149  int flatIndex = 0;
1150  int stride = 1;
1151  for (int i = rank - 1; i >= 0; --i) {
1152  flatIndex += indices[i] * stride;
1153  stride *= tensorType.getDimSize(i);
1154  }
1155  // Prevent out of bounds accesses. This can happen in invalid code that
1156  // will never execute.
1157  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1158  flatIndex < 0)
1159  return {};
1160  return fromElementsOp.getElements()[flatIndex];
1161  }
1162 
1163  // If this is an elements attribute, query the value at the given indices.
1164  if (Attribute tensor = adaptor.getTensor()) {
1165  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1166  if (elementsAttr && elementsAttr.isValidIndex(indices))
1167  return elementsAttr.getValues<Attribute>()[indices];
1168  }
1169 
1170  return {};
1171 }
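// Folding examples (added annotation, not part of the upstream file):
//
//   %cst = arith.constant dense<7> : tensor<2x3xi32>
//   %v   = tensor.extract %cst[%i, %j] : tensor<2x3xi32>   // folds to 7
//
//   %c1 = arith.constant 1 : index
//   %t  = tensor.from_elements %a, %b : tensor<2xf32>
//   %x  = tensor.extract %t[%c1] : tensor<2xf32>           // folds to %b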
1172 
1173 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1174  MLIRContext *context) {
1175  results.add<ExtractFromTensorCast>(context);
1176 }
1177 
1178 //===----------------------------------------------------------------------===//
1179 // FromElementsOp
1180 //===----------------------------------------------------------------------===//
1181 
1182 void FromElementsOp::getAsmResultNames(
1183  function_ref<void(Value, StringRef)> setNameFn) {
1184  setNameFn(getResult(), "from_elements");
1185 }
1186 
1187 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1188  ValueRange elements) {
1189  assert(!elements.empty() && "expected at least one element");
1190  Type resultType = RankedTensorType::get(
1191  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1192  build(builder, result, resultType, elements);
1193 }
1194 
1195 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1196  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1197  return DenseElementsAttr::get(getType(), adaptor.getElements());
1198  return {};
1199 }
1200 
1201 namespace {
1202 
1203 // Pushes the index_casts that occur before extractions to after the extract.
1204 // This minimizes type conversion in some cases and enables the extract
1205 // canonicalizer. This changes:
1206 //
1207 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1208 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1209 //
1210 // to the following:
1211 //
1212 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1213 // %cast = arith.index_cast %extract : i32 to index
1214 //
1215 // and, after further canonicalization, to just the extracted element.
1216 //
1217 // Consider expanding this to a template and handle all tensor cast
1218 // operations.
1219 struct ExtractElementFromIndexCast
1220  : public OpRewritePattern<tensor::ExtractOp> {
1221  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1222 
1223  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1224  PatternRewriter &rewriter) const final {
1225  Location loc = extract.getLoc();
1226  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1227  if (!indexCast)
1228  return failure();
1229 
1230  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1231 
1232  auto newExtract = rewriter.create<tensor::ExtractOp>(
1233  loc, elementTy, indexCast.getIn(), extract.getIndices());
1234 
1235  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1236  newExtract);
1237 
1238  return success();
1239  }
1240 };
1241 
1242 } // namespace
1243 
1244 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1245  MLIRContext *context) {
1246  results.add<ExtractElementFromIndexCast>(context);
1247 }
1248 
1249 //===----------------------------------------------------------------------===//
1250 // GatherOp
1251 //===----------------------------------------------------------------------===//
1252 
1253 void GatherOp::getAsmResultNames(
1254  function_ref<void(Value, StringRef)> setNameFn) {
1255  setNameFn(getResult(), "gather");
1256 }
1257 
1258 /// Return the inferred result type for a gatherOp where:
1259 /// - sourceType is the type of the source tensor gathered from
1260 /// - indicesType is the type of the indices used to gather
1261 /// - gatherDims are the dims along which the gather occurs.
1262 /// Return a full rank or ranked-reduced variant of the type depending on
1263 /// the value of rankReduced.
1264 ///
1265 /// The leading dimensions of the index tensor give the result tensor its
1266 /// leading dimensions.
1267 /// The trailing dimensions of the result tensor are obtained from the source
1268 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1269 /// rankedReduced is false), or skipping them (otherwise).
1270 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1271  RankedTensorType indicesType,
1272  ArrayRef<int64_t> gatherDims,
1273  bool rankReduced) {
1274  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1275  resultShape.reserve(resultShape.size() + sourceType.getRank());
1276  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1277  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1278  if (!rankReduced)
1279  resultShape.push_back(1);
1280  continue;
1281  }
1282  resultShape.push_back(sourceType.getDimSize(idx));
1283  }
1284  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1285 }
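// Worked example (added annotation, not part of the upstream file): with
// gather_dims = [1], source tensor<4x5x6xf32>, and indices tensor<2x3x1xindex>,
// the leading result dims come from the indices (2x3) and the source dims
// follow with the gathered dim either kept as a unit dim or skipped:
//
//   full rank:    tensor<2x3x4x1x6xf32>
//   rank-reduced: tensor<2x3x4x6xf32>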
1286 
1287 static LogicalResult
1288 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
1289  StringRef gatherOrScatter, StringRef sourceOrDest) {
1290  if (dims.empty())
1291  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1292 
1293  int64_t numGatherDims = dims.size();
1294  if (numGatherDims > rank)
1295  return op->emitOpError(gatherOrScatter)
1296  << "_dims overflow " << sourceOrDest << " rank";
1297  for (int64_t val : dims) {
1298  if (val < 0)
1299  return op->emitOpError(gatherOrScatter)
1300  << "_dims value must be non-negative";
1301  if (val >= rank)
1302  return op->emitOpError(gatherOrScatter)
1303  << "_dims value must be smaller than " << sourceOrDest << " rank";
1304  }
1305  for (int64_t i = 1; i < numGatherDims; ++i) {
1306  if (dims[i - 1] >= dims[i])
1307  return op->emitOpError(gatherOrScatter)
1308  << "_dims values must be strictly increasing";
1309  }
1310  return success();
1311 }
1312 
1313 LogicalResult GatherOp::verify() {
1314  int64_t sourceRank = getSourceType().getRank();
1315  ArrayRef<int64_t> gatherDims = getGatherDims();
1316  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
1317  "gather", "source")))
1318  return failure();
1319 
1320  RankedTensorType expectedResultType = GatherOp::inferResultType(
1321  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1322  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1323  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1324  if (getResultType() != expectedResultType &&
1325  getResultType() != expectedRankReducedResultType) {
1326  return emitOpError("result type "
1327  "mismatch: "
1328  "expected ")
1329  << expectedResultType << " or its rank-reduced variant "
1330  << expectedRankReducedResultType << " (got: " << getResultType()
1331  << ")";
1332  }
1333 
1334  return success();
1335 }
1336 
1337 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1338  if (OpFoldResult reshapedSource = reshapeConstantSource(
1339  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1340  getResult().getType()))
1341  return reshapedSource;
1342  return {};
1343 }
1344 
1345 //===----------------------------------------------------------------------===//
1346 // InsertOp
1347 //===----------------------------------------------------------------------===//
1348 
1349 void InsertOp::getAsmResultNames(
1350  function_ref<void(Value, StringRef)> setNameFn) {
1351  setNameFn(getResult(), "inserted");
1352 }
1353 
1354 LogicalResult InsertOp::verify() {
1355  // Verify the # indices match if we have a ranked type.
1356  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1357  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1358  return emitOpError("incorrect number of indices");
1359  return success();
1360 }
1361 
1362 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1363  Attribute scalar = adaptor.getScalar();
1364  Attribute dest = adaptor.getDest();
1365  if (scalar && dest)
1366  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1367  if (scalar == splatDest.getSplatValue<Attribute>())
1368  return dest;
1369  return {};
1370 }
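// Example (added annotation, not part of the upstream file): inserting a value
// that equals the splat value of the destination is a no-op and folds to the
// destination:
//
//   %cst = arith.constant dense<0.0> : tensor<4xf32>
//   %z   = arith.constant 0.0 : f32
//   %r   = tensor.insert %z into %cst[%i] : tensor<4xf32>   // folds to %cst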
1371 
1372 //===----------------------------------------------------------------------===//
1373 // GenerateOp
1374 //===----------------------------------------------------------------------===//
1375 
1376 void GenerateOp::getAsmResultNames(
1377  function_ref<void(Value, StringRef)> setNameFn) {
1378  setNameFn(getResult(), "generated");
1379 }
1380 
1381 LogicalResult GenerateOp::reifyResultShapes(
1382  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1383  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1384  int idx = 0;
1385  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1386  if (getType().isDynamicDim(dim)) {
1387  reifiedReturnShapes[0][dim] = getOperand(idx++);
1388  } else {
1389  reifiedReturnShapes[0][dim] =
1390  builder.getIndexAttr(getType().getDimSize(dim));
1391  }
1392  }
1393  return success();
1394 }
1395 
1396 LogicalResult GenerateOp::verify() {
1397  // Ensure that the tensor type has as many dynamic dimensions as are
1398  // specified by the operands.
1399  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1400  if (getNumOperands() != resultType.getNumDynamicDims())
1401  return emitError("must have as many index operands as dynamic extents "
1402  "in the result type");
1403  return success();
1404 }
1405 
1406 LogicalResult GenerateOp::verifyRegions() {
1407  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1408  // Ensure that region arguments span the index space.
1409  if (!llvm::all_of(getBody().getArgumentTypes(),
1410  [](Type ty) { return ty.isIndex(); }))
1411  return emitError("all body arguments must be index");
1412  if (getBody().getNumArguments() != resultTy.getRank())
1413  return emitError("must have one body argument per input dimension");
1414 
1415  // Ensure that the region yields an element of the right type.
1416  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1417 
1418  if (yieldOp.getValue().getType() != resultTy.getElementType())
1419  return emitOpError(
1420  "body must be terminated with a `yield` operation of the tensor "
1421  "element type");
1422 
1423  return success();
1424 }
1425 
1426 void GenerateOp::build(
1427  OpBuilder &b, OperationState &result, Type resultTy,
1428  ValueRange dynamicExtents,
1429  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1430  build(b, result, resultTy, dynamicExtents);
1431 
1432  // Build and populate body.
1433  OpBuilder::InsertionGuard guard(b);
1434  Region *bodyRegion = result.regions.front().get();
1435  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1436  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1437  SmallVector<Location, 2> argumentLocs(rank, result.location);
1438  Block *bodyBlock =
1439  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1440  bodyBuilder(b, result.location, bodyBlock->getArguments());
1441 }
1442 
1443 namespace {
1444 
1445 /// Canonicalizes tensor.generate operations with a constant
1446 /// operand into the equivalent operation with the operand expressed in the
1447 /// result type, instead. We also insert a type cast to make sure that the
1448 /// resulting IR is still well-typed.
1449 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1450  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1451 
1452  LogicalResult matchAndRewrite(GenerateOp generateOp,
1453  PatternRewriter &rewriter) const final {
1454  SmallVector<Value> foldedDynamicSizes;
1455  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1456  generateOp.getType(), generateOp.getDynamicExtents(),
1457  foldedDynamicSizes);
1458 
1459  // Stop here if no dynamic size was promoted to static.
1460  if (foldedTensorType == generateOp.getType())
1461  return failure();
1462 
1463  auto loc = generateOp.getLoc();
1464  auto newOp =
1465  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1466  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1467  newOp.getBody().begin());
1468  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1469  generateOp.getType(), newOp);
1470  return success();
1471  }
1472 };
1473 
1474 /// Canonicalizes the pattern of the form
1475 ///
1476 /// %tensor = tensor.generate %x {
1477 /// ^bb0(%arg0: index):
1478 /// <computation>
1479 /// yield %1 : index
1480 /// } : tensor<?xindex>
1481 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1482 ///
1483 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1484 /// tensor.generate operation has no side-effects.
1485 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1486  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1487 
1488  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1489  PatternRewriter &rewriter) const final {
1490  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1491  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1492  return failure();
1493 
1494  IRMapping mapping;
1495  Block *body = &tensorFromElements.getBody().front();
1496  mapping.map(body->getArguments(), extract.getIndices());
1497  for (auto &op : body->without_terminator())
1498  rewriter.clone(op, mapping);
1499 
1500  auto yield = cast<YieldOp>(body->getTerminator());
1501 
1502  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1503  return success();
1504  }
1505 };
1506 
1507 } // namespace
1508 
1509 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1510  MLIRContext *context) {
1511  // TODO: Move extract pattern to tensor::ExtractOp.
1512  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1513 }
1514 
1515 //===----------------------------------------------------------------------===//
1516 // RankOp
1517 //===----------------------------------------------------------------------===//
1518 
1519 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1520  setNameFn(getResult(), "rank");
1521 }
1522 
1523 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1524  // Constant fold rank when the rank of the operand is known.
1525  auto type = getOperand().getType();
1526  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1527  if (shapedType && shapedType.hasRank())
1528  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1529  return IntegerAttr();
1530 }
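// Example (added annotation, not part of the upstream file): the rank of any
// ranked operand folds to an index constant; unranked operands are left alone:
//
//   %r = tensor.rank %t : tensor<4x?xf32>   // folds to the constant 2
//   %u = tensor.rank %s : tensor<*xf32>     // not folded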
1531 
1532 //===----------------------------------------------------------------------===//
1533 // ReshapeOp
1534 //===----------------------------------------------------------------------===//
1535 
1536 void ReshapeOp::getAsmResultNames(
1537  function_ref<void(Value, StringRef)> setNameFn) {
1538  setNameFn(getResult(), "reshape");
1539 }
1540 
1541 static int64_t getNumElements(ShapedType type) {
1542  int64_t numElements = 1;
1543  for (auto dim : type.getShape())
1544  numElements *= dim;
1545  return numElements;
1546 }
1547 
1548 LogicalResult ReshapeOp::verify() {
1549  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1550  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1551 
1552  if (operandType.getElementType() != resultType.getElementType())
1553  return emitOpError("element types of source and destination tensor "
1554  "types should be the same");
1555 
1556  int64_t shapeSize =
1557  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1558  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1559  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1560 
1561  if (resultRankedType) {
1562  if (operandRankedType && resultRankedType.hasStaticShape() &&
1563  operandRankedType.hasStaticShape()) {
1564  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1565  return emitOpError("source and destination tensor should have the "
1566  "same number of elements");
1567  }
1568  if (ShapedType::isDynamic(shapeSize))
1569  return emitOpError("cannot use shape operand with dynamic length to "
1570  "reshape to statically-ranked tensor type");
1571  if (shapeSize != resultRankedType.getRank())
1572  return emitOpError(
1573  "length of shape operand differs from the result's tensor rank");
1574  }
1575  return success();
1576 }
1577 
1578 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1579  if (OpFoldResult reshapedSource = reshapeConstantSource(
1580  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1581  getResult().getType()))
1582  return reshapedSource;
1583  return {};
1584 }
1585 
1586 //===----------------------------------------------------------------------===//
1587 // Reassociative reshape ops
1588 //===----------------------------------------------------------------------===//
1589 
1590 void CollapseShapeOp::getAsmResultNames(
1591  function_ref<void(Value, StringRef)> setNameFn) {
1592  setNameFn(getResult(), "collapsed");
1593 }
1594 
1595 void ExpandShapeOp::getAsmResultNames(
1596  function_ref<void(Value, StringRef)> setNameFn) {
1597  setNameFn(getResult(), "expanded");
1598 }
1599 
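/// For example, for an expand_shape with reassociation [[0, 1], [2]], result
/// dimensions 0 and 1 map to source dimension 0 and result dimension 2 maps
/// to source dimension 1.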
1600 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1601  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1602  "invalid resultDim");
1603  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1604  if (llvm::is_contained(it.value(), resultDim))
1605  return it.index();
1606  llvm_unreachable("could not find reassociation group");
1607 }
1608 
1609 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1610  return getSymbolLessAffineMaps(getReassociationExprs());
1611 }
1612 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1613  return convertReassociationIndicesToExprs(getContext(),
1614  getReassociationIndices());
1615 }
1616 
1617 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1618  return getSymbolLessAffineMaps(getReassociationExprs());
1619 }
1620 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1621  return convertReassociationIndicesToExprs(getContext(),
1622  getReassociationIndices());
1623 }
1624 
1625 RankedTensorType CollapseShapeOp::inferCollapsedType(
1626  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1627  return inferCollapsedType(
1628  type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1629  type.getContext(), reassociation)));
1630 }
1631 
1632 /// Compute the RankedTensorType obtained by applying `reassociation` to
1633 /// `type`.
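/// For example, collapsing tensor<2x3x?xf32> with reassociation [[0, 1], [2]]
/// yields tensor<6x?xf32>: each group collapses to the product of its sizes,
/// or to a dynamic size if any member of the group is dynamic.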
1634 RankedTensorType
1635 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1636  ArrayRef<AffineMap> reassociation) {
1637  auto shape = type.getShape();
1638  SmallVector<int64_t, 4> newShape;
1639  newShape.reserve(reassociation.size());
1640 
1641  // Use the fact that reassociation is valid to simplify the logic: only use
1642  // each map's rank.
1643  assert(isReassociationValid(reassociation) && "invalid reassociation");
1644  unsigned currentDim = 0;
1645  for (AffineMap m : reassociation) {
1646  unsigned dim = m.getNumResults();
1647  auto band = shape.slice(currentDim, dim);
1648  int64_t size = 1;
1649  if (llvm::is_contained(band, ShapedType::kDynamic))
1650  size = ShapedType::kDynamic;
1651  else
1652  for (unsigned d = 0; d < dim; ++d)
1653  size *= shape[currentDim + d];
1654  newShape.push_back(size);
1655  currentDim += dim;
1656  }
1657 
1658  return RankedTensorType::get(newShape, type.getElementType());
1659 }
1660 
1661 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1662  ArrayRef<ReassociationIndices> reassociation,
1663  ArrayRef<NamedAttribute> attrs) {
1664  auto resultType = inferCollapsedType(
1665  llvm::cast<RankedTensorType>(src.getType()),
1666  getSymbolLessAffineMaps(
1667  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1668  build(b, result, resultType, src, attrs);
1669  result.addAttribute(getReassociationAttrStrName(),
1670  getReassociationIndicesAttribute(b, reassociation));
1671 }
1672 
1673 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1674  TensorReshapeOp, ExpandShapeOp>::value>
1675 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1676  RankedTensorType expandedType,
1677  RankedTensorType collapsedType) {
1678  if (failed(
1679  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1680  return failure();
1681 
1682  auto maps = op.getReassociationMaps();
1683  RankedTensorType expectedType =
1684  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1685  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1686  return op.emitOpError("expected collapsed type to be ")
1687  << expectedType << ", but got " << collapsedType;
1688  return success();
1689 }
1690 
1691 LogicalResult ExpandShapeOp::verify() {
1692  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
1693 }
1694 
1695 LogicalResult CollapseShapeOp::verify() {
1696  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1697 }
1698 
1699 namespace {
1700 /// Reshape of a splat constant can be replaced with a constant of the result
1701 /// type.
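/// For example:
///   %cst = arith.constant dense<1.0> : tensor<2x3xf32>
///   %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
/// becomes
///   %0 = arith.constant dense<1.0> : tensor<6xf32>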
1702 template <typename TensorReshapeOp>
1703 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1704  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1705  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1706  PatternRewriter &rewriter) const override {
1707  DenseElementsAttr attr;
1708  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1709  return failure();
1710  if (!attr || !attr.isSplat())
1711  return failure();
1712  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1713  reshapeOp.getResultType(), attr.getRawData());
1714  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1715  return success();
1716  }
1717 };
1718 
1719 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1720 template <typename TensorReshapeOp>
1721 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1722 public:
1723  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1724 
1725  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1726  PatternRewriter &rewriter) const override {
1727  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1728  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1729  return failure();
1730 
1731  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1732  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1733  return success();
1734  }
1735 };
1736 
1737 /// Reshape of a FromElements can be replaced with a FromElements of the
1738 /// result type.
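/// For example:
///   %fe = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>
///   %0 = tensor.collapse_shape %fe [[0, 1]] : tensor<2x2xf32> into tensor<4xf32>
/// becomes
///   %0 = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>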
1739 template <typename TensorReshapeOp>
1740 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1741  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1742  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1743  PatternRewriter &rewriter) const override {
1744  auto fromElements =
1745  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1746  if (!fromElements)
1747  return failure();
1748 
1749  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1750 
1751  if (!shapedTy.hasStaticShape())
1752  return failure();
1753 
1754  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1755  fromElements.getElements());
1756  return success();
1757  }
1758 };
1759 
1760 // Fold CastOp into CollapseShapeOp when adding static information.
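// For example:
//   %c = tensor.cast %s : tensor<2x3x4xf32> to tensor<?x?x?xf32>
//   %r = tensor.collapse_shape %c [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
// becomes a collapse_shape of %s (with the static type tensor<6x4xf32>)
// followed by a cast back to tensor<?x?xf32>.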
1761 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1762  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1763 
1764  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1765  PatternRewriter &rewriter) const override {
1766  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1767  if (!tensor::canFoldIntoConsumerOp(castOp))
1768  return failure();
1769 
1770  RankedTensorType srcType =
1771  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1772  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1773  srcType, collapseShapeOp.getReassociationMaps());
1774 
1775  if (newResultType == collapseShapeOp.getResultType()) {
1776  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1777  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1778  });
1779  } else {
1780  auto newOp = rewriter.create<CollapseShapeOp>(
1781  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1782  collapseShapeOp.getReassociation());
1783  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1784  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1785  }
1786  return success();
1787  }
1788 };
1789 
1790 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1791  using OpRewritePattern<DimOp>::OpRewritePattern;
1792 
1793  LogicalResult matchAndRewrite(DimOp dimOp,
1794  PatternRewriter &rewriter) const override {
1795  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1796  if (!expandShapeOp)
1797  return failure();
1798 
1799  // Only constant dimension values are supported.
1800  std::optional<int64_t> dim = dimOp.getConstantIndex();
1801  if (!dim.has_value())
1802  return failure();
1803 
1804  // Skip static dims. These are folded to constant ops.
1805  RankedTensorType resultType = expandShapeOp.getResultType();
1806  if (!resultType.isDynamicDim(*dim))
1807  return failure();
1808 
1809  // Find reassociation group that contains this result dimension.
1810  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1811 
1812  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1813  // ExpandShapeOp would be ambiguous.)
1814  int64_t product = 1;
1815  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1816  for (int64_t d : grp) {
1817  if (d != dim) {
1818  assert(!resultType.isDynamicDim(d) && "expected static dim");
1819  product *= resultType.getDimSize(d);
1820  }
1821  }
1822 
1823  // result dim size = src dim size / (product(other dims in reassoc group))
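  // E.g. expanding tensor<?x?xf32> into tensor<?x4x?x8xf32> with reassociation
  // [[0, 1], [2, 3]]: size(result dim 0) = size(src dim 0) floordiv 4.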
1824  Value srcDimSz =
1825  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1826  AffineExpr expr;
1827  bindSymbols(dimOp.getContext(), expr);
1828  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1829  dimOp, expr.floorDiv(product), srcDimSz);
1830  return success();
1831  }
1832 };
1833 
1834 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
1835  using OpRewritePattern<DimOp>::OpRewritePattern;
1836 
1837  LogicalResult matchAndRewrite(DimOp dimOp,
1838  PatternRewriter &rewriter) const override {
1839  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1840  if (!collapseShapeOp)
1841  return failure();
1842 
1843  // Only constant dimension values are supported.
1844  std::optional<int64_t> dim = dimOp.getConstantIndex();
1845  if (!dim.has_value())
1846  return failure();
1847 
1848  // Skip static dims. These are folded to constant ops.
1849  RankedTensorType resultType = collapseShapeOp.getResultType();
1850  if (!resultType.isDynamicDim(*dim))
1851  return failure();
1852 
1853  // Get reassociation group of the result dimension.
1854  ReassociationIndices group =
1855  collapseShapeOp.getReassociationIndices()[*dim];
1856 
1857  // result dim size = product(dims in reassoc group)
1858  SmallVector<Value> srcDimSizes;
1859  SmallVector<AffineExpr> syms;
1860  AffineExpr product;
1861  for (const auto &it : llvm::enumerate(group)) {
1862  srcDimSizes.push_back(rewriter.create<DimOp>(
1863  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1864  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
1865  product = product ? product * syms.back() : syms.back();
1866  }
1867  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
1868  srcDimSizes);
1869  return success();
1870  }
1871 };
1872 } // namespace
1873 
1874 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1875  MLIRContext *context) {
1876  results.add<
1877  ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
1878  FoldReshapeWithConstant<ExpandShapeOp>,
1879  FoldReshapeWithSplat<ExpandShapeOp>,
1880  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1881  FoldDimOfCollapseShape>(context);
1882 }
1883 
1884 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1885  MLIRContext *context) {
1886  results
1887  .add<ComposeReassociativeReshapeOps<CollapseShapeOp,
1888  ReshapeOpKind::kCollapse>,
1889  FoldReshapeWithConstant<CollapseShapeOp>,
1890  FoldReshapeWithSplat<CollapseShapeOp>,
1891  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
1892  context);
1893 }
1894 
1895 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
1896  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
1897  adaptor.getOperands());
1898 }
1899 
1900 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
1901  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
1902  adaptor.getOperands());
1903 }
1904 
1905 //===----------------------------------------------------------------------===//
1906 // ExtractSliceOp
1907 //===----------------------------------------------------------------------===//
1908 
1909 void ExtractSliceOp::getAsmResultNames(
1910  function_ref<void(Value, StringRef)> setNameFn) {
1911  setNameFn(getResult(), "extracted_slice");
1912 }
1913 
1914 /// An extract_slice result type can be inferred, when it is not
1915 /// rank-reduced, from the source type and the static representation of
1916 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
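/// For example, extracting from tensor<8x16x4xf32> with static sizes
/// [4, kDynamic, 4] infers tensor<4x?x4xf32>: each static size becomes a
/// static result dimension and each dynamic size a dynamic one.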
1917 RankedTensorType ExtractSliceOp::inferResultType(
1918  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
1919  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
1920  // An extract_slice op may specify only a leading subset of offset/sizes/
1921  // strides, in which case we complete with offset=0, sizes from the source
1922  // tensor type, and strides=1.
1923  assert(static_cast<int64_t>(staticSizes.size()) ==
1924  sourceTensorType.getRank() &&
1925  "unexpected staticSizes not equal to rank of source");
1926  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
1927 }
1928 
1929 RankedTensorType ExtractSliceOp::inferResultType(
1930  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
1931  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
1932  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1933  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1934  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1935  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1936  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1937  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
1938  staticSizes, staticStrides);
1939 }
1940 
1941 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
1942 /// number of sizes), drop as many size 1 as needed to produce an inferred
1943 /// type with the desired rank.
1944 ///
1945 /// Note that there may be multiple ways to compute this rank-reduced type:
1946 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
1947 ///
1948 /// To disambiguate, this function always drops the earliest size-1 dimensions first.
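/// For example, with sizes 1x6x1 and desiredResultRank = 2, dimension 0 is
/// dropped and the inferred type is tensor<6x1xf32> (not tensor<1x6xf32>).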
1949 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1950  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1951  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1952  ArrayRef<int64_t> strides) {
1953  // Type inferred in the absence of rank-reducing behavior.
1954  auto inferredType = llvm::cast<RankedTensorType>(
1955  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
1956  int rankDiff = inferredType.getRank() - desiredResultRank;
1957  if (rankDiff > 0) {
1958  auto shape = inferredType.getShape();
1959  llvm::SmallBitVector dimsToProject =
1960  getPositionsOfShapeOne(rankDiff, shape);
1961  SmallVector<int64_t> projectedShape;
1962  // Best effort rank-reducing: drop 1s in order.
1963  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1964  if (!dimsToProject.test(pos))
1965  projectedShape.push_back(shape[pos]);
1966  inferredType =
1967  RankedTensorType::get(projectedShape, inferredType.getElementType());
1968  }
1969  return inferredType;
1970 }
1971 
1972 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1973  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1974  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1975  ArrayRef<OpFoldResult> strides) {
1976  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1977  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1978  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1979  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1980  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1981  return ExtractSliceOp::inferCanonicalRankReducedResultType(
1982  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1983  staticStrides);
1984 }
1985 
1986 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1987 /// result type. If the type passed is nullptr, it is inferred.
1988 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1989  RankedTensorType resultType, Value source,
1990  ArrayRef<OpFoldResult> offsets,
1991  ArrayRef<OpFoldResult> sizes,
1992  ArrayRef<OpFoldResult> strides,
1993  ArrayRef<NamedAttribute> attrs) {
1994  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1995  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1996  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1997  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1998  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1999  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2000  // Structuring implementation this way avoids duplication between builders.
2001  if (!resultType) {
2002  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2003  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2004  }
2005  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2006  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2007  b.getDenseI64ArrayAttr(staticSizes),
2008  b.getDenseI64ArrayAttr(staticStrides));
2009  result.addAttributes(attrs);
2010 }
2011 
2012 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2013 /// result type.
2014 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2015  ArrayRef<OpFoldResult> offsets,
2016  ArrayRef<OpFoldResult> sizes,
2017  ArrayRef<OpFoldResult> strides,
2018  ArrayRef<NamedAttribute> attrs) {
2019  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2020 }
2021 
2022 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2023 /// a Range vector.
2024 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2025  ArrayRef<Range> ranges,
2026  ArrayRef<NamedAttribute> attrs) {
2027  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2028  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2029 }
2030 
2031 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2032 /// the type passed is nullptr, it is inferred.
2033 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2034  RankedTensorType resultType, Value source,
2035  ValueRange offsets, ValueRange sizes,
2036  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2037  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2038  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2039  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2040  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2041  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2042  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2043  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2044 }
2045 
2046 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2047 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2048  ValueRange offsets, ValueRange sizes,
2049  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2050  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2051 }
2052 
2053 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2054  Operation *op,
2055  RankedTensorType expectedType) {
2056  switch (result) {
2057  case SliceVerificationResult::Success:
2058  return success();
2059  case SliceVerificationResult::RankTooLarge:
2060  return op->emitError("expected rank to be smaller or equal to ")
2061  << "the other rank. ";
2062  case SliceVerificationResult::SizeMismatch:
2063  return op->emitError("expected type to be ")
2064  << expectedType << " or a rank-reduced version. (size mismatch) ";
2065  case SliceVerificationResult::ElemTypeMismatch:
2066  return op->emitError("expected element type to be ")
2067  << expectedType.getElementType();
2068  default:
2069  llvm_unreachable("unexpected extract_slice op verification result");
2070  }
2071 }
2072 
2073 /// Verifier for ExtractSliceOp.
2074 LogicalResult ExtractSliceOp::verify() {
2075  // Verify result type against inferred type.
2076  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2077  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2078  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2079  return produceSliceErrorMsg(result, *this, expectedType);
2080 }
2081 
2082 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2083  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2084 }
2085 
2086 FailureOr<Value>
2087 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2088  ArrayRef<int64_t> desiredShape) {
2089  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2090  assert(sourceTensorType && "not a ranked tensor type");
2091  auto sourceShape = sourceTensorType.getShape();
2092  if (sourceShape.equals(desiredShape))
2093  return value;
2094  auto maybeRankReductionMask =
2095  mlir::computeRankReductionMask(sourceShape, desiredShape);
2096  if (!maybeRankReductionMask)
2097  return failure();
2098  return createCanonicalRankReducingExtractSliceOp(
2099  b, loc, value,
2100  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2101 }
2102 
2103 LogicalResult ExtractSliceOp::reifyResultShapes(
2104  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2105  reifiedReturnShapes.resize(1);
2106  reifiedReturnShapes[0].reserve(getType().getRank());
2107  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2108  llvm::SmallBitVector droppedDims = getDroppedDims();
2109  for (const auto &size : enumerate(mixedSizes)) {
2110  if (droppedDims.test(size.index()))
2111  continue;
2112  reifiedReturnShapes[0].push_back(size.value());
2113  }
2114  return success();
2115 }
2116 
2117 namespace {
2118 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2119 /// This essentially pushes the tensor.cast past its consuming slice when
2120 /// `canFoldIntoConsumerOp` is true.
2121 ///
2122 /// Example:
2123 /// ```
2124 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2125 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2126 /// tensor<3x4xf32>
2127 /// ```
2128 /// is rewritten into:
2129 /// ```
2130 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
2131 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2132 /// ```
2133 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2134 public:
2135  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2136 
2137  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2138  PatternRewriter &rewriter) const override {
2139  // Any constant operand, just return to let the constant folder kick in.
2140  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2141  return matchPattern(operand, matchConstantIndex());
2142  }))
2143  return failure();
2144 
2145  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2146  if (!castOp)
2147  return failure();
2148 
2149  if (!canFoldIntoConsumerOp(castOp))
2150  return failure();
2151 
2152  // Create folded extract.
2153  Location loc = sliceOp.getLoc();
2154  Value newResult = rewriter.create<ExtractSliceOp>(
2155  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2156  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2157  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2158  if (newResult.getType() != sliceOp.getType())
2159  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2160  rewriter.replaceOp(sliceOp, newResult);
2161  return success();
2162  }
2163 };
2164 
2165 /// Slice elements from `values` into `outValues`. `counts` represents the
2166 /// numbers of elements to stride in the original values for each dimension.
2167 /// The output values can be used to construct a DenseElementsAttr.
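/// For example, slicing a 4x4 constant with offsets [1, 0], sizes [2, 2], and
/// strides [1, 2] copies rows 1 and 2 and, within each row, columns 0 and 2.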
2168 template <typename IterTy, typename ElemTy>
2169 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2170  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2171  ArrayRef<int64_t> strides,
2172  llvm::SmallVectorImpl<ElemTy> *outValues) {
2173  assert(offsets.size() == sizes.size());
2174  assert(offsets.size() == strides.size());
2175  if (offsets.empty())
2176  return;
2177 
2178  int64_t offset = offsets.front();
2179  int64_t size = sizes.front();
2180  int64_t stride = strides.front();
2181  if (offsets.size() == 1) {
2182  for (int64_t i = 0; i < size; ++i, offset += stride)
2183  outValues->push_back(*(values + offset));
2184 
2185  return;
2186  }
2187 
2188  for (int64_t i = 0; i < size; ++i, offset += stride) {
2189  auto begin = values + offset * counts.front();
2190  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2191  offsets.drop_front(), sizes.drop_front(),
2192  strides.drop_front(), outValues);
2193  }
2194 }
2195 
2196 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2197 /// folded operation might introduce more constant data; users can control
2198 /// whether the fold applies via the control function.
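/// A typical control function only allows the fold when the resulting constant
/// stays small, e.g.:
///   [](ExtractSliceOp op) {
///     return op.getType().hasStaticShape() &&
///            op.getType().getNumElements() <= 32;
///   }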
2199 class ConstantOpExtractSliceFolder final
2200  : public OpRewritePattern<ExtractSliceOp> {
2201 public:
2202  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2203 
2204  ConstantOpExtractSliceFolder(MLIRContext *context,
2205  ControlConstantExtractSliceFusionFn controlFn)
2206  : OpRewritePattern<ExtractSliceOp>(context),
2207  controlFn(std::move(controlFn)) {}
2208 
2209  LogicalResult matchAndRewrite(ExtractSliceOp op,
2210  PatternRewriter &rewriter) const override {
2211  DenseElementsAttr attr;
2212  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2213  return failure();
2214 
2215  // A constant splat is handled by fold().
2216  if (attr.isSplat())
2217  return failure();
2218 
2219  // Dynamic result shape is not supported.
2220  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2221  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2222  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2223  return failure();
2224 
2225  // Customized control over the folding.
2226  if (!controlFn(op))
2227  return failure();
2228 
2229  int64_t count = sourceType.getNumElements();
2230  if (count == 0)
2231  return failure();
2232 
2233  // Check if there are any dynamic parts, which are not supported.
2234  auto offsets = op.getStaticOffsets();
2235  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2236  return failure();
2237  auto sizes = op.getStaticSizes();
2238  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2239  return failure();
2240  auto strides = op.getStaticStrides();
2241  if (llvm::is_contained(strides, ShapedType::kDynamic))
2242  return failure();
2243 
2244  // Compute the stride for each dimension.
2245  SmallVector<int64_t> counts;
2246  ArrayRef<int64_t> shape = sourceType.getShape();
2247  counts.reserve(shape.size());
2248  for (int64_t v : shape) {
2249  count = count / v;
2250  counts.push_back(count);
2251  }
2252 
2253  // New attribute constructed by the sliced values.
2254  DenseElementsAttr newAttr;
2255 
2256  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2257  SmallVector<APInt> outValues;
2258  outValues.reserve(sourceType.getNumElements());
2259  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2260  elems.begin(), counts, offsets, sizes, strides, &outValues);
2261  newAttr = DenseElementsAttr::get(resultType, outValues);
2262  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2263  SmallVector<APFloat> outValues;
2264  outValues.reserve(sourceType.getNumElements());
2265  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2266  elems.begin(), counts, offsets, sizes, strides, &outValues);
2267  newAttr = DenseElementsAttr::get(resultType, outValues);
2268  }
2269 
2270  if (newAttr) {
2271  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2272  return success();
2273  }
2274 
2275  return failure();
2276  }
2277 
2278 private:
2279  /// Controls whether the fold happens; users can plug their own
2280  /// heuristics into this function.
2281  ControlConstantExtractSliceFusionFn controlFn;
2282 };
2283 
2284 } // namespace
2285 
2286 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2287  RewritePatternSet &patterns,
2288  const ControlConstantExtractSliceFusionFn &controlFn) {
2289  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2290 }
2291 
2292 /// Return the canonical type of the result of an extract_slice op.
2293 struct SliceReturnTypeCanonicalizer {
2294  RankedTensorType operator()(ExtractSliceOp op,
2295  ArrayRef<OpFoldResult> mixedOffsets,
2296  ArrayRef<OpFoldResult> mixedSizes,
2297  ArrayRef<OpFoldResult> mixedStrides) {
2298  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2299  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2300  mixedStrides);
2301  }
2302 };
2303 
2304 /// A canonicalizer wrapper to replace ExtractSliceOps.
2305 struct SliceCanonicalizer {
2306  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2307  ExtractSliceOp newOp) {
2308  Value replacement = newOp.getResult();
2309  if (replacement.getType() != op.getType())
2310  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2311  replacement);
2312  rewriter.replaceOp(op, replacement);
2313  }
2314 };
2315 
2316 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2317  MLIRContext *context) {
2318  results.add<
2319  OpWithOffsetSizesAndStridesConstantArgumentFolder<
2320  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2321  ExtractSliceOpCastFolder>(context);
2322 }
2323 
2324 // Succeeds if `op` is an identity slice of `shapedType`: all offsets are 0, the sizes match the (leading) dimensions of the shape, and all strides are 1.
2325 static LogicalResult
2326 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2327  ShapedType shapedType) {
2328  OpBuilder b(op.getContext());
2329  for (OpFoldResult ofr : op.getMixedOffsets())
2330  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2331  return failure();
2332  // Rank-reducing noops only need to inspect the leading dimensions:
2333  // llvm::zip is appropriate.
2334  auto shape = shapedType.getShape();
2335  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2336  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2337  return failure();
2338  for (OpFoldResult ofr : op.getMixedStrides())
2339  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2340  return failure();
2341  return success();
2342 }
2343 
2344 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2345 /// slice, we can return the InsertSliceOp's source directly.
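/// For example:
///   %0 = tensor.insert_slice %src into %dst[0, 1] [2, 2] [1, 1]
///   %1 = tensor.extract_slice %0[0, 1] [2, 2] [1, 1]
/// folds %1 to %src, since the inserted and extracted slices coincide and the
/// types match.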
2346 // TODO: This only checks the immediate producer; extend to go up the
2347 // insert/extract chain if the slices are disjoint.
2348 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2349  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2350 
2351  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2352  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2353  insertOp.isSameAs(extractOp, isSame))
2354  return insertOp.getSource();
2355 
2356  return {};
2357 }
2358 
2359 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2360  if (OpFoldResult reshapedSource = reshapeConstantSource(
2361  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2362  getResult().getType()))
2363  return reshapedSource;
2364  if (getSourceType() == getType() &&
2365  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2366  return this->getSource();
2367  if (Value slice = foldExtractAfterInsertSlice(*this))
2368  return slice;
2369 
2370  return OpFoldResult();
2371 }
2372 
2373 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2374  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2375  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2376  unsigned rank = rankedTensorType.getRank();
2377  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2378  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2379  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2380  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2381  offsets, sizes, strides);
2382 }
2383 
2384 //===----------------------------------------------------------------------===//
2385 // InsertSliceOp
2386 //===----------------------------------------------------------------------===//
2387 
2388 void InsertSliceOp::getAsmResultNames(
2389  function_ref<void(Value, StringRef)> setNameFn) {
2390  setNameFn(getResult(), "inserted_slice");
2391 }
2392 
2393 // Build an InsertSliceOp with mixed static and dynamic entries.
2394 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2395  Value dest, ArrayRef<OpFoldResult> offsets,
2396  ArrayRef<OpFoldResult> sizes,
2397  ArrayRef<OpFoldResult> strides,
2398  ArrayRef<NamedAttribute> attrs) {
2399  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2400  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2401  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2402  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2403  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2404  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2405  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2406  b.getDenseI64ArrayAttr(staticSizes),
2407  b.getDenseI64ArrayAttr(staticStrides));
2408  result.addAttributes(attrs);
2409 }
2410 
2411 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2412 /// Range vector.
2413 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2414  Value dest, ArrayRef<Range> ranges,
2415  ArrayRef<NamedAttribute> attrs) {
2416  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2417  build(b, result, source, dest, offsets, sizes, strides, attrs);
2418 }
2419 
2420 // Build an InsertSliceOp with dynamic entries.
2421 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2422  Value dest, ValueRange offsets, ValueRange sizes,
2423  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2424  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2425  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2426  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2427  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2428  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2429  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2430  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2431 }
2432 
2433 /// Rank-reducing type verification for both InsertSliceOp and
2434 /// ParallelInsertSliceOp.
2435 static SliceVerificationResult verifyInsertSliceOp(
2436  RankedTensorType srcType, RankedTensorType dstType,
2437  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2438  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2439  // insert_slice is the inverse of extract_slice, use the same type
2440  // inference.
2441  RankedTensorType expected = ExtractSliceOp::inferResultType(
2442  dstType, staticOffsets, staticSizes, staticStrides);
2443  if (expectedType)
2444  *expectedType = expected;
2445  return isRankReducedType(expected, srcType);
2446 }
2447 
2448 /// Verifier for InsertSliceOp.
2449 LogicalResult InsertSliceOp::verify() {
2450  RankedTensorType expectedType;
2451  SliceVerificationResult result =
2452  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2453  getStaticSizes(), getStaticStrides(), &expectedType);
2454  return produceSliceErrorMsg(result, *this, expectedType);
2455 }
2456 
2457 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2458 /// can mutate the second InsertSliceOp's destination to the first one's.
2459 ///
2460 /// Example:
2461 ///
2462 /// ```mlir
2463 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2464 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2465 /// ```
2466 ///
2467 /// folds into:
2468 ///
2469 /// ```mlir
2470 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2471 /// ```
2472 ///
2473 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2474 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2475  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2476 
2477  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2478  if (!prevInsertOp ||
2479  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2480  !prevInsertOp.isSameAs(insertOp, isSame))
2481  return failure();
2482 
2483  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2484  return success();
2485 }
2486 
2487 /// Folds round-trip extract/insert slice op pairs.
2488 /// Example:
2489 /// ```mlir
2490 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2491 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2492 /// ```
2493 /// can be folded into %val.
2494 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2495  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2496 
2497  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2498  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2499  !extractOp.isSameAs(insertOp, isSame))
2500  return nullptr;
2501 
2502  return extractOp.getSource();
2503 }
2504 
2505 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2506  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2507  getSourceType() == getType() &&
2508  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2509  return this->getSource();
2510  if (succeeded(foldInsertAfterInsertSlice(*this)))
2511  return getResult();
2512  if (auto result = foldInsertAfterExtractSlice(*this))
2513  return result;
2514  return OpFoldResult();
2515 }
2516 
2517 LogicalResult InsertSliceOp::reifyResultShapes(
2518  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2519  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2520  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2521  return success();
2522 }
2523 
2524 namespace {
2525 /// Pattern to rewrite an insert_slice op with constant arguments.
2526 ///
2527 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
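/// For example, an insert_slice whose offsets are the constants %c0 and %c1 is
/// rewritten to use the static offsets [0, 1], casting the source to the
/// inferred (more static) type when necessary.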
2528 template <typename InsertOpTy>
2529 class InsertSliceOpConstantArgumentFolder final
2530  : public OpRewritePattern<InsertOpTy> {
2531 public:
2532  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2533 
2534  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2535  PatternRewriter &rewriter) const override {
2536  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2537  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2538  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2539 
2540  // No constant operands were folded, just return.
2541  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2542  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2543  failed(foldDynamicStrideList(mixedStrides)))
2544  return failure();
2545 
2546  // Create the new op in canonical form.
2547  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2548  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2549  mixedOffsets, mixedSizes, mixedStrides);
2550  Value toInsert = insertSliceOp.getSource();
2551  if (sourceType != insertSliceOp.getSourceType()) {
2552  OpBuilder::InsertionGuard g(rewriter);
2553  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2554  // is that the insertion point is just before the ParallelCombiningOp in
2555  // the parallel case.
2556  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2557  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2558  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2559  sourceType, toInsert);
2560  }
2561  rewriter.replaceOpWithNewOp<InsertOpTy>(
2562  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2563  mixedSizes, mixedStrides);
2564  return success();
2565  }
2566 };
2567 
2568 /// Fold tensor_casts with insert_slice operations. If the source or
2569 /// destination tensor is a tensor_cast that removes static type information,
2570 /// the cast is folded into the insert_slice operation. E.g.:
2571 ///
2572 /// ```mlir
2573 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2574 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2575 /// ```
2576 ///
2577 /// folds into:
2578 ///
2579 /// ```mlir
2580 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2581 /// ```
2582 ///
2583 /// Note: When folding a cast on the destination tensor, the result of the
2584 /// insert_slice operation is cast to ensure that the type of the result did
2585 /// not change.
2586 ///
2587 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2588 template <typename InsertOpTy>
2589 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2590  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2591 
2592  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2593  PatternRewriter &rewriter) const override {
2594  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2595  return matchPattern(operand, matchConstantIndex());
2596  }))
2597  return failure();
2598 
2599  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2600  auto castOp = v.getDefiningOp<tensor::CastOp>();
2601  if (!castOp || !canFoldIntoConsumerOp(castOp))
2602  return std::nullopt;
2603  return castOp.getSource();
2604  };
2605  std::optional<Value> sourceCastSource =
2606  getSourceOfCastOp(insertSliceOp.getSource());
2607  std::optional<Value> destCastSource =
2608  getSourceOfCastOp(insertSliceOp.getDest());
2609  if (!sourceCastSource && !destCastSource)
2610  return failure();
2611 
2612  auto src =
2613  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2614  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2615  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2616  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2617  if (!srcType || !dstType)
2618  return failure();
2619  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2620  insertSliceOp.getStaticSizes(),
2621  insertSliceOp.getStaticStrides()) !=
2622  SliceVerificationResult::Success)
2623  return failure();
2624 
2625  Operation *replacement = rewriter.create<InsertOpTy>(
2626  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2627  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2628 
2629  // In the parallel case there is no result and so nothing to cast.
2630  bool isParallelInsert =
2631  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2632  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2633  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2634  insertSliceOp.getDestType(),
2635  replacement->getResult(0));
2636  }
2637  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2638  return success();
2639  }
2640 };
2641 
2642 /// If additional static type information can be deduced from an insert_slice's
2643 /// size operands, insert an explicit cast of the op's source operand. This
2644 /// enables other canonicalization patterns that are matching for tensor_cast
2645 /// ops such as `ForOpTensorCastFolder` in SCF.
2646 ///
2647 /// Example:
2648 ///
2649 /// ```mlir
2650 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2651 /// : tensor<?x?xf32> into ...
2652 /// ```
2653 ///
2654 /// folds into:
2655 ///
2656 /// ```mlir
2657 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2658 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2659 /// : tensor<64x64xf32> into ...
2660 /// ```
2661 ///
2662 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2663 template <typename InsertOpTy>
2664 struct InsertSliceOpSourceCastInserter final
2665  : public OpRewritePattern<InsertOpTy> {
2666  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2667 
2668  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2669  PatternRewriter &rewriter) const override {
2670  RankedTensorType srcType = insertSliceOp.getSourceType();
2671  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2672  return failure();
2673  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
2674  srcType.getShape().end());
2675  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2676  if (std::optional<int64_t> constInt =
2677  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2678  // Bail on invalid IR.
2679  if (*constInt < 0)
2680  return failure();
2681  newSrcShape[i] = *constInt;
2682  }
2683  }
2684  if (!hasValidSizesOffsets(newSrcShape))
2685  return failure();
2686 
2687  RankedTensorType newSrcType = RankedTensorType::get(
2688  newSrcShape, srcType.getElementType(), srcType.getEncoding());
2689  if (srcType == newSrcType ||
2690  !preservesStaticInformation(srcType, newSrcType) ||
2691  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2692  return failure();
2693 
2694  // newSrcType is:
2695  // 1) Different from srcType.
2696  // 2) "More static" than srcType.
2697  // 3) Cast-compatible with srcType.
2698  // Insert the cast.
2699  OpBuilder::InsertionGuard g(rewriter);
2700  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2701  // that the insertion point is just before the ParallelCombiningOp in the
2702  // parallel case.
2703  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2704  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2705  Value cast = rewriter.create<tensor::CastOp>(
2706  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2707  rewriter.replaceOpWithNewOp<InsertOpTy>(
2708  insertSliceOp, cast, insertSliceOp.getDest(),
2709  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2710  insertSliceOp.getMixedStrides());
2711  return success();
2712  }
2713 };
2714 } // namespace
2715 
2716 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2717  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2718 }
2719 
2720 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2721  MLIRContext *context) {
2722  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2723  InsertSliceOpCastFolder<InsertSliceOp>,
2724  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2725 }
2726 
2727 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
2728  Location loc,
2729  Value tensor,
2730  Value dest) {
2731  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
2732  unsigned rank = rankedTensorType.getRank();
2733  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2734  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
2735  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2736  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2737  sizes, strides);
2738 }
2739 
2740 //===----------------------------------------------------------------------===//
2741 // PadOp
2742 //===----------------------------------------------------------------------===//
2743 
2744 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
2745  setNameFn(getResult(), "padded");
2746 }
2747 
2748 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
2749 // supports optional types.
2750 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
2751  Type typeToInfer, Type typeToInferFrom) {}
2752 
2753 ParseResult
2754 parseInferType(OpAsmParser &parser,
2755  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2756  Type &typeToInfer, Type typeToInferFrom) {
2757  if (optOperand)
2758  typeToInfer = typeToInferFrom;
2759  return success();
2760 }
2761 
2762 LogicalResult PadOp::verify() {
2763  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
2764  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
2765  auto expectedType =
2766  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2767  if (!expectedType) {
2768  return emitError("failed to infer expectedType from sourceType ")
2769  << sourceType << ", specified resultType is " << resultType;
2770  }
2771  if (resultType.getRank() != expectedType.getRank()) {
2772  return emitError("specified type ")
2773  << resultType << " does not match the inferred type "
2774  << expectedType;
2775  }
2776  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
2777  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2778  continue;
2779  if (expectedType.isDynamicDim(i))
2780  continue;
2781  return emitError("specified type ")
2782  << resultType << " does not match the inferred type "
2783  << expectedType;
2784  }
2785 
2786  return success();
2787 }
2788 
2789 LogicalResult PadOp::verifyRegions() {
2790  auto &region = getRegion();
2791  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2792  Block &block = region.front();
2793  if (block.getNumArguments() != rank)
2794  return emitError("expected the block to have ") << rank << " arguments";
2795 
2796  // Note: the number and type of yield values are checked in the YieldOp.
2797  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
2798  if (!en.value().isIndex())
2799  return emitOpError("expected block argument ")
2800  << (en.index() + 1) << " to be an index";
2801  }
2802 
2803  // Ensure that the region yields an element of the right type.
2804  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
2805  if (yieldOp.getValue().getType() !=
2806  llvm::cast<ShapedType>(getType()).getElementType())
2807  return emitOpError("expected yield type to match shape element type");
2808 
2809  return success();
2810 }
2811 
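/// For example, padding tensor<4x?xf32> with static low padding [1, 0] and
/// static high padding [2, 3] infers tensor<7x?xf32>: a result dimension is
/// source + low + high when all three are static, and dynamic otherwise.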
2812 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2813  ArrayRef<int64_t> staticLow,
2814  ArrayRef<int64_t> staticHigh,
2815  ArrayRef<int64_t> resultShape) {
2816  unsigned rank = sourceType.getRank();
2817  if (staticLow.size() != rank)
2818  return RankedTensorType();
2819  if (staticHigh.size() != rank)
2820  return RankedTensorType();
2821  if (!resultShape.empty() && resultShape.size() != rank)
2822  return RankedTensorType();
2823 
2824  SmallVector<int64_t, 4> inferredShape;
2825  for (auto i : llvm::seq<unsigned>(0, rank)) {
2826  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2827  staticHigh[i] == ShapedType::kDynamic) {
2828  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2829  : resultShape[i]);
2830  } else {
2831  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2832  assert((resultShape.empty() || size == resultShape[i] ||
2833  resultShape[i] == ShapedType::kDynamic) &&
2834  "mismatch between inferred shape and result shape");
2835  inferredShape.push_back(size);
2836  }
2837  }
2838 
2839  return RankedTensorType::get(inferredShape, sourceType.getElementType());
2840 }
2841 
2842 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2843  Value source, ArrayRef<int64_t> staticLow,
2844  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
2845  bool nofold, ArrayRef<NamedAttribute> attrs) {
2846  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2847  if (!resultType)
2848  resultType = inferResultType(sourceType, staticLow, staticHigh);
2849  build(b, result, resultType, source, low, high,
2850  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2851  nofold ? b.getUnitAttr() : UnitAttr());
2852  result.addAttributes(attrs);
2853 }
2854 
2855 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2856  Value source, ValueRange low, ValueRange high, bool nofold,
2857  ArrayRef<NamedAttribute> attrs) {
2858  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2859  unsigned rank = sourceType.getRank();
2860  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
2861  build(b, result, resultType, source, staticVector, staticVector, low, high,
2862  nofold, attrs);
2863 }
2864 
2865 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2866  Value source, ArrayRef<OpFoldResult> low,
2867  ArrayRef<OpFoldResult> high, bool nofold,
2868  ArrayRef<NamedAttribute> attrs) {
2869  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
2870  SmallVector<Value, 4> dynamicLow, dynamicHigh;
2871  SmallVector<int64_t, 4> staticLow, staticHigh;
2872  // staticLow and staticHigh carry the complete padding configuration: each
2873  // entry grows staticLow and staticHigh by one value. If an entry is
2874  // dynamic (i.e. not a constant), dynamicLow and dynamicHigh grow by one
2875  // value as well.
2876  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
2877  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
2878  if (!resultType) {
2879  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
2880  }
2881  assert(llvm::isa<RankedTensorType>(resultType));
2882  build(b, result, resultType, source, dynamicLow, dynamicHigh,
2883  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
2884  nofold ? b.getUnitAttr() : UnitAttr());
2885  result.addAttributes(attrs);
2886 }
2887 
2888 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
2889  Value source, ArrayRef<OpFoldResult> low,
2890  ArrayRef<OpFoldResult> high, Value constantPadValue,
2891  bool nofold, ArrayRef<NamedAttribute> attrs) {
2892  build(b, result, resultType, source, low, high, nofold, attrs);
2893 
2894  // Add a region and a block to yield the pad value.
2895  Region *region = result.regions[0].get();
2896  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
2897  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
2898  SmallVector<Location> blockArgLocs(sourceRank, result.location);
2899 
2900  // `builder.createBlock` changes the insertion point within the block. Create
2901  // a guard to reset the insertion point of the builder after it is destroyed.
2902  OpBuilder::InsertionGuard guard(b);
2903  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
2904  b.create<tensor::YieldOp>(result.location, constantPadValue);
2905 }
2906 
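/// Returns the set of dimensions whose low or high padding amount is not
/// statically known to be zero.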
2907 llvm::SmallBitVector PadOp::getPaddedDims() {
2908  llvm::SmallBitVector paddedDims(getSourceType().getRank());
2909  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
2910  for (const auto &en : enumerate(paddingWidths))
2911  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
2912  paddedDims.set(en.index());
2913  };
2914  extractPaddedDims(getMixedLowPad());
2915  extractPaddedDims(getMixedHighPad());
2916  return paddedDims;
2917 }
2918 
2919 namespace {
2920 // Folds tensor.pad when the padding is all static zeros and the nofold
2921 // attribute is not set.
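// For example, tensor.pad %t low[0, 0] high[0, 0] (without nofold) is replaced
// by a tensor.cast of %t to the pad's result type.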
2922 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
2923  using OpRewritePattern<PadOp>::OpRewritePattern;
2924 
2925  LogicalResult matchAndRewrite(PadOp padTensorOp,
2926  PatternRewriter &rewriter) const override {
2927  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
2928  return failure();
2929  if (padTensorOp.getNofold())
2930  return failure();
2931  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2932  padTensorOp, padTensorOp.getResult().getType(),
2933  padTensorOp.getSource());
2934  return success();
2935  }
2936 };
2937 
2938 // Fold CastOp into PadOp when adding static information.
2939 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
2940  using OpRewritePattern<PadOp>::OpRewritePattern;
2941 
2942  LogicalResult matchAndRewrite(PadOp padTensorOp,
2943  PatternRewriter &rewriter) const override {
2944  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
2945  if (!tensor::canFoldIntoConsumerOp(castOp))
2946  return failure();
2947 
2948  auto newResultType = PadOp::inferResultType(
2949  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
2950  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2951  padTensorOp.getResultType().getShape());
2952 
2953  if (newResultType == padTensorOp.getResultType()) {
2954  rewriter.modifyOpInPlace(padTensorOp, [&]() {
2955  padTensorOp.getSourceMutable().assign(castOp.getSource());
2956  });
2957  } else {
2958  auto newOp = rewriter.create<PadOp>(
2959  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
2960  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2961  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
2962  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2963  IRMapping mapper;
2964  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
2965 
2966  rewriter.replaceOpWithNewOp<tensor::CastOp>(
2967  padTensorOp, padTensorOp.getResultType(), newOp);
2968  }
2969  return success();
2970  }
2971 };
2972 
2973 // Fold CastOp using the result of PadOp back into the latter if it adds
2974 // static information.
2975 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
2976  using OpRewritePattern<PadOp>::OpRewritePattern;
2977 
2978  LogicalResult matchAndRewrite(PadOp padTensorOp,
2979  PatternRewriter &rewriter) const override {
2980  if (!padTensorOp.getResult().hasOneUse())
2981  return failure();
2982  auto tensorCastOp =
2983  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2984  if (!tensorCastOp)
2985  return failure();
2986  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
2987  tensorCastOp.getDest().getType()))
2988  return failure();
2989 
2990  auto replacementOp = rewriter.create<PadOp>(
2991  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2992  padTensorOp.getSource(), padTensorOp.getStaticLow(),
2993  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
2994  padTensorOp.getHigh(), padTensorOp.getNofold(),
2995  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
2996  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
2997 
2998  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
2999  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3000  return success();
3001  }
3002 };
3003 
3004 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3005 /// different dimensions. The pattern applies if the following preconditions
3006 /// hold:
3007 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3008 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3009 /// 3) the tensor::PadOps perform only high-padding,
3010 /// 4) the tensor::PadOps have the same constant padding value,
3011 /// 5) the tensor::PadOps do not have common padding dimensions,
3012 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3013 /// zero-offset for every dimension.
3014 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3015 /// the padded source dimensions.
3016 ///
3017 ///
3018 /// Example:
3019 ///
3020 /// ```mlir
3021 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3022 /// : tensor<64x64xf32> to tensor<?x64xf32>
3023 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3024 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3025 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3026 /// : tensor<8x64xf32> to tensor<8x?xf32>
3027 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3028 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3029 /// ```
3030 ///
3031 /// folds into:
3032 ///
3033 /// ```mlir
3034 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3035 /// : tensor<64x64xf32> to tensor<?x?xf32>
3036 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3037 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3038 /// ```
3039 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3040  using OpRewritePattern<PadOp>::OpRewritePattern;
3041 
3042  LogicalResult matchAndRewrite(PadOp padOp,
3043  PatternRewriter &rewriter) const override {
3044  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3045  if (!innerSliceOp)
3046  return failure();
3047  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3048  if (!outerPadOp || outerPadOp.getNofold())
3049  return failure();
3050  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3051  if (!outerSliceOp)
3052  return failure();
3053 
3054  // 1) Fail if the chain is rank-reducing.
3055  int64_t rank = padOp.getSourceType().getRank();
3056  if (outerSliceOp.getSourceType().getRank() != rank) {
3057  return rewriter.notifyMatchFailure(padOp,
3058  "cannot fold rank-reducing chain");
3059  }
3060 
3061  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3062  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3063  return rewriter.notifyMatchFailure(
3064  padOp, "cannot fold non-unit stride ExtractSliceOps");
3065  }
3066 
3067  // 3) Fail if the tensor::PadOps have non-zero low padding.
3068  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3069  return rewriter.notifyMatchFailure(padOp,
3070  "cannot fold PadOps with low padding");
3071  }
3072 
3073  // 4) Fail if the tensor::PadOps padding values do not match.
3074  Attribute innerAttr, outerAttr;
3075  Value innerValue = padOp.getConstantPaddingValue();
3076  Value outerValue = outerPadOp.getConstantPaddingValue();
3077  if (!innerValue || !outerValue ||
3078  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3079  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3080  innerAttr != outerAttr) {
3081  return rewriter.notifyMatchFailure(
3082  padOp, "cannot fold PadOps with different padding values");
3083  }
3084 
3085  // 5) Fail if a dimension is padded by both tensor::PadOps.
3086  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3087  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3088  if (innerDims.anyCommon(outerDims)) {
3089  return rewriter.notifyMatchFailure(
3090  padOp, "cannot fold PadOps with common padding dimensions");
3091  }
3092 
3093  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3094  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3095  // for every dimension, and use the offset of the other pair. Fail if no
3096  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3097  // exists.
3098  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3099  for (auto en : enumerate(newOffsets)) {
3100  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3101  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3102  if (!innerDims.test(en.index()) &&
3103  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3104  en.value() = outerOffset;
3105  continue;
3106  }
3107  if (!outerDims.test(en.index()) &&
3108  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3109  en.value() = innerOffset;
3110  continue;
3111  }
3112  return rewriter.notifyMatchFailure(
3113  padOp, "cannot find zero-offset and zero-padding pair");
3114  }
3115 
3116  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3117  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3118  // outer tensor::PadOp and fail if the size of the inner
3119  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3120  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3121  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3122  for (auto en : enumerate(newSizes)) {
3123  if (!outerDims.test(en.index()))
3124  continue;
3125  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3126  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3127  assert(!ShapedType::isDynamic(sourceSize) &&
3128  "expected padded dimension to have a static size");
3129  if (getConstantIntValue(sliceSize) != sourceSize) {
3130  return rewriter.notifyMatchFailure(
3131  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3132  "match the size of the outer padding");
3133  }
3134  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3135  }
3136 
3137  // Combine the high paddings of the two tensor::PadOps.
3138  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3139  for (auto en : enumerate(newHighPad)) {
3140  if (innerDims.test(en.index()))
3141  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3142  if (outerDims.test(en.index()))
3143  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3144  }
3145 
3146  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3147  // the two paddings in one step.
3148  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3149  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3150  innerSliceOp.getMixedStrides());
3151  auto newPadOp = rewriter.create<PadOp>(
3152  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3153  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3154  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3155  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3156  newPadOp.getRegion().begin());
3157  rewriter.replaceOp(padOp, newPadOp.getResult());
3158  return success();
3159  }
3160 };
3161 
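// Fold low/high padding operands that are defined by constants into the static
// padding attributes and into the result type. Illustrative example
// (hypothetical IR, not part of the original file):
//
// ```mlir
// %c1 = arith.constant 1 : index
// %0 = tensor.pad %src low[%c1, 0] high[1, 2] { ...
//     } : tensor<4x8xf32> to tensor<?x10xf32>
// ```
//
// is rewritten to:
//
// ```mlir
// %0 = tensor.pad %src low[1, 0] high[1, 2] { ...
//     } : tensor<4x8xf32> to tensor<6x10xf32>
// %1 = tensor.cast %0 : tensor<6x10xf32> to tensor<?x10xf32>
// ```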
3162 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3163  using OpRewritePattern<PadOp>::OpRewritePattern;
3164 
3165  LogicalResult matchAndRewrite(PadOp padTensorOp,
3166  PatternRewriter &rewriter) const override {
3167  Value input = padTensorOp.getSource();
3168  if (!llvm::isa<RankedTensorType>(input.getType()))
3169  return failure();
3170  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3171  auto inputRank = inputDims.size();
3172 
3173  auto oldResultType =
3174  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3175  if (!oldResultType)
3176  return failure();
3177 
3178  auto outputDims = oldResultType.getShape();
3179 
3180  // Extract the static info from the high and low operands.
3181  SmallVector<int64_t> constOperandsLow;
3182  SmallVector<Value> newLows;
3183  for (auto operand : padTensorOp.getLow()) {
3184  APSInt intOp;
3185  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3186  constOperandsLow.push_back(ShapedType::kDynamic);
3187  newLows.push_back(operand);
3188  continue;
3189  }
3190  constOperandsLow.push_back(intOp.getExtValue());
3191  }
3192  SmallVector<int64_t> constOperandsHigh;
3193  SmallVector<Value> newHighs;
3194  for (auto operand : padTensorOp.getHigh()) {
3195  APSInt intOp;
3196  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3197  constOperandsHigh.push_back(ShapedType::kDynamic);
3198  newHighs.push_back(operand);
3199  continue;
3200  }
3201  constOperandsHigh.push_back(intOp.getExtValue());
3202  }
3203 
3204  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3205  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3206 
3207  // Verify the op is well-formed.
3208  if (inputDims.size() != outputDims.size() ||
3209  inputDims.size() != constLow.size() ||
3210  inputDims.size() != constHigh.size())
3211  return failure();
3212 
3213  auto lowCount = 0;
3214  auto highCount = 0;
3215  for (size_t i = 0; i < inputRank; i++) {
3216  if (constLow[i] == ShapedType::kDynamic)
3217  constLow[i] = constOperandsLow[lowCount++];
3218  if (constHigh[i] == ShapedType::kDynamic)
3219  constHigh[i] = constOperandsHigh[highCount++];
3220  }
3221 
3222  auto staticLow = ArrayRef<int64_t>(constLow);
3223  auto staticHigh = ArrayRef<int64_t>(constHigh);
3224 
3225  // Calculate the output sizes with the static information.
3226  SmallVector<int64_t> newOutDims;
3227  for (size_t i = 0; i < inputRank; i++) {
3228  if (outputDims[i] == ShapedType::kDynamic) {
3229  newOutDims.push_back(
3230  (staticLow[i] == ShapedType::kDynamic ||
3231  staticHigh[i] == ShapedType::kDynamic ||
3232  inputDims[i] == ShapedType::kDynamic
3233  ? ShapedType::kDynamic
3234  : inputDims[i] + staticLow[i] + staticHigh[i]));
3235  } else {
3236  newOutDims.push_back(outputDims[i]);
3237  }
3238  }
3239 
3240  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3241  llvm::all_of(newOutDims,
3242  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3243  return failure();
3244 
3245  // Rewrite the op using the new static type.
3246  auto newResultType = RankedTensorType::get(
3247  newOutDims, padTensorOp.getType().getElementType());
3248  auto newOp = rewriter.create<PadOp>(
3249  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3250  newLows, newHighs, padTensorOp.getNofold(),
3251  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3252 
3253  IRMapping mapper;
3254  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3255  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3256  newOp);
3257 
3258  return success();
3259  }
3260 };
3261 
3262 } // namespace
3263 
3264 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3265  MLIRContext *context) {
3266  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3267  FoldOrthogonalPaddings, FoldStaticPadding>(context);
3268 }
3269 
3270 /// Return the padding value of the PadOp if it is constant. In this context,
3271 /// "constant" means an actual constant or "defined outside of the block".
3272 ///
3273 /// Values are considered constant in three cases:
3274 /// - A ConstantLike value.
3275 /// - A basic block argument from a different block.
3276 /// - A value defined outside of the block.
3277 ///
3278 /// If the padding value is not constant, an empty Value is returned.
3279 Value PadOp::getConstantPaddingValue() {
3280  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3281  if (!yieldOp)
3282  return {};
3283  Value padValue = yieldOp.getValue();
3284  // Check if yield value is a constant.
3285  if (matchPattern(padValue, m_Constant()))
3286  return padValue;
3287  // Check if yield value is defined inside the PadOp block.
3288  if (padValue.getParentBlock() == &getRegion().front())
3289  return {};
3290  // Else: Yield value defined outside of the PadOp block.
3291  return padValue;
3292 }
3293 
3294 OpFoldResult PadOp::fold(FoldAdaptor) {
3295  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3296  !getNofold())
3297  return getSource();
3298  return {};
3299 }
3300 
3301 //===----------------------------------------------------------------------===//
3302 // ParallelInsertSliceOp
3303 //===----------------------------------------------------------------------===//
3304 
3305 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3306  ParallelCombiningOpInterface parallelCombiningParent =
3307  getParallelCombiningParent();
3308  for (const auto &it :
3309  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3310  Operation &nextOp = it.value();
3311  if (&nextOp == getOperation())
3312  return parallelCombiningParent.getParentResult(it.index());
3313  }
3314  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3315 }
3316 
3317 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3318 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3319  Value source, Value dest,
3320  ArrayRef<OpFoldResult> offsets,
3321  ArrayRef<OpFoldResult> sizes,
3322  ArrayRef<OpFoldResult> strides,
3323  ArrayRef<NamedAttribute> attrs) {
3324  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3325  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3326  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3327  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3328  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3329  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3330  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3331  b.getDenseI64ArrayAttr(staticSizes),
3332  b.getDenseI64ArrayAttr(staticStrides));
3333  result.addAttributes(attrs);
3334 }
3335 
3336 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3337 /// packed into a Range vector.
3338 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3339  Value source, Value dest,
3340  ArrayRef<Range> ranges,
3341  ArrayRef<NamedAttribute> attrs) {
3342  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3343  build(b, result, source, dest, offsets, sizes, strides, attrs);
3344 }
3345 
3346 // Build a ParallelInsertSliceOp with dynamic entries.
3347 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3348  Value source, Value dest, ValueRange offsets,
3349  ValueRange sizes, ValueRange strides,
3350  ArrayRef<NamedAttribute> attrs) {
3351  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3352  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3353  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3354  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3355  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3356  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3357  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3358 }
3359 
3360 LogicalResult ParallelInsertSliceOp::verify() {
3361  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3362  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3363  << *(getOperation()->getParentOp());
3364 
3365  RankedTensorType expectedType;
3366  SliceVerificationResult result =
3367  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3368  getStaticSizes(), getStaticStrides(), &expectedType);
3369  return produceSliceErrorMsg(result, *this, expectedType);
3370 }
3371 
3372 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3373  RewritePatternSet &results, MLIRContext *context) {
3374  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3375  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3376  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3377 }
3378 
3379 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3380  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3381 }
3382 
3383 //===----------------------------------------------------------------------===//
3384 // ScatterOp
3385 //===----------------------------------------------------------------------===//
3386 
3387 void ScatterOp::getAsmResultNames(
3388  function_ref<void(Value, StringRef)> setNameFn) {
3389  setNameFn(getResult(), "scatter");
3390 }
3391 
3392 LogicalResult ScatterOp::verify() {
3393  int64_t destRank = getDestType().getRank();
3394  ArrayRef<int64_t> scatterDims = getScatterDims();
3395  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3396  "scatter", "dest")))
3397  return failure();
3398 
3399  if (!getUnique())
3400  return emitOpError("requires 'unique' attribute to be set");
3401  // TODO: we could also check statically that there are fewer leading index
3402  // tensor dims than the dest dims. If this is not the case, the unique
3403  // attribute cannot be true.
3404 
3405  // Use the GatherOp::inferResultType on the `dest` type and verify the
3406  // expected type matches the source type.
3407  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3408  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3409  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3410  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3411  if (getSourceType() != expectedSourceType &&
3412  getSourceType() != expectedRankReducedSourceType) {
3413  return emitOpError("source type "
3414  "mismatch: "
3415  "expected ")
3416  << expectedSourceType << " or its rank-reduced variant "
3417  << expectedRankReducedSourceType << " (got: " << getSourceType()
3418  << ")";
3419  }
3420 
3421  return success();
3422 }
3423 
3424 //===----------------------------------------------------------------------===//
3425 // SplatOp
3426 //===----------------------------------------------------------------------===//
3427 
3428 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3429  Type aggregateType, ValueRange dynamicSizes) {
3430  build(builder, result, aggregateType, element, dynamicSizes);
3431 }
3432 
3433 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3434  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3435  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3436  build(builder, result, aggregateType, element, dynamicSizes);
3437 }
3438 
3439 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3440  ArrayRef<OpFoldResult> sizes) {
3441  SmallVector<int64_t> staticShape;
3442  SmallVector<Value> dynamicSizes;
3443  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3444  build(builder, result, element, staticShape, dynamicSizes);
3445 }
3446 
3447 void SplatOp::getAsmResultNames(
3448  function_ref<void(Value, StringRef)> setNameFn) {
3449  setNameFn(getResult(), "splat");
3450 }
3451 
3452 LogicalResult SplatOp::verify() {
3453  if (getType().getNumDynamicDims() !=
3454  static_cast<int64_t>(getDynamicSizes().size()))
3455  return emitOpError("incorrect number of dynamic sizes, has ")
3456  << getDynamicSizes().size() << ", expected "
3457  << getType().getNumDynamicDims();
3458  return success();
3459 }
3460 
3461 LogicalResult
3462 SplatOp::reifyResultShapes(OpBuilder &builder,
3463  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3464  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3465  unsigned ctr = 0;
3466  for (int64_t i = 0; i < getType().getRank(); ++i) {
3467  if (getType().isDynamicDim(i)) {
3468  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3469  } else {
3470  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3471  }
3472  }
3473  return success();
3474 }
3475 
3476 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3477  auto constOperand = adaptor.getInput();
3478  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
3479  return {};
3480 
3481  // Do not fold if the splat is not statically shaped
3482  if (!getType().hasStaticShape())
3483  return {};
3484 
3485  // SplatElementsAttr::get treats single value for second arg as being a
3486  // splat.
3487  return SplatElementsAttr::get(getType(), {constOperand});
3488 }
3489 
3490 //===----------------------------------------------------------------------===//
3491 // PackOp/UnPackOp Common
3492 //===----------------------------------------------------------------------===//
3493 
3494 template <typename OpTy>
3495 static LogicalResult
3496 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3497  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3498  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3499  "applies to only pack or unpack operations");
3500  int64_t destRank = op.getDestRank();
3501  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3502  reifiedReturnShapes[0] =
3503  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3504  return success();
3505 }
3506 
3507 template <typename OpTy>
3508 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3509  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3510  "applies to only pack or unpack operations");
3511  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3512  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3513  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3514  assert(tiles.size() == dimsToTile.size() &&
3515  "tiles must match indices of dimension to block");
3516  // bind the dimension `i` with the tile factor.
3517  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3518  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3519  return dimAndTileMapping;
3520 }
3521 
3522 template <typename OpTy>
3523 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3524  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3525  "applies to only pack or unpack operations");
3526  Builder builder(op);
3527  SmallVector<OpFoldResult> mixedInnerTiles;
3528  unsigned dynamicValIndex = 0;
3529  for (int64_t staticTile : op.getStaticInnerTiles()) {
3530  if (!ShapedType::isDynamic(staticTile))
3531  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3532  else
3533  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3534  }
3535  return mixedInnerTiles;
3536 }
3537 
3538 template <typename OpTy>
3539 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3540  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3541  "applies to only pack or unpack operations");
3542  SmallVector<Value> dynamicTiles;
3543  SmallVector<int64_t> staticTiles;
3544  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3545  return staticTiles;
3546 }
3547 
3548 /// Returns true if `dimsPos` is invalid. It is invalid when:
3549 /// a) It contains duplicates.
3550 /// b) At least one dimension is out of bounds (valid: 0 <= `dimPos` < `rank`).
3551 /// c) The number of elements in `dimsPos` is greater than `rank`.
3552 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3553  size_t rank) {
3554  size_t dimsPosSize = dimsPos.size();
3555  if (dimsPosSize > rank)
3556  return true;
3557  DenseSet<int64_t> uniqued;
3558  for (int64_t dim : dimsPos)
3559  uniqued.insert(dim);
3560  if (dimsPosSize != uniqued.size())
3561  return true;
3562  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3563  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3564  });
3565 }
3566 
3567 /// Returns true if each dimension of `sourceShape` is smaller than or equal
3568 /// to the corresponding dimension of `limitShape` (dynamic dims are in bound).
3569 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3570  ArrayRef<int64_t> limitShape) {
3571  assert(
3572  sourceShape.size() == limitShape.size() &&
3573  "expected source shape rank, and limit of the shape to have same rank");
3574  return llvm::all_of(
3575  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3576  int64_t sourceExtent = std::get<0>(it);
3577  int64_t limit = std::get<1>(it);
3578  return ShapedType::isDynamic(sourceExtent) ||
3579  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3580  });
3581 }
3582 
3583 template <typename OpTy>
3584 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3585  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3586  "applies to only pack or unpack operations");
3587  Operation *op = packOrUnPack.getOperation();
3588 
3589  // Return true if we have a zero-value tile.
3590  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3591  return llvm::any_of(
3592  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3593  };
3594 
3595  // Verify tiles. Do not allow zero tiles.
3596  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3597  if (hasZeros(mixedTiles))
3598  return op->emitError("invalid zero tile factor");
3599 
3600  // Verify inner_dims_pos and outer_dims_perm.
3601  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3602  ? packOrUnPack.getSourceType()
3603  : packOrUnPack.getDestType();
3604  size_t unpackedRank = unpackedType.getRank();
3605  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3606  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3607  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3608  return op->emitError("invalid inner_dims_pos vector");
3609  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3610  return op->emitError("invalid outer_dims_perm vector");
3611  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3612  return op->emitError("outer_dims_perm must be a permutation or empty");
3613 
3614  // Tiling factors must be less than or equal to the input rank for pack (or
3615  // output rank for unpack), and must match the number of `inner_dims_pos`.
3616  if (mixedTiles.size() > unpackedRank) {
3617  return op->emitError("tiling factors must be less than or equal to the "
3618  "input rank for pack or output rank for unpack");
3619  }
3620  if (mixedTiles.size() != innerDimsPos.size()) {
3621  return op->emitError(
3622  "tiling factors must equal the number of dimensions to tile");
3623  }
3624 
3625  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3626  ? packOrUnPack.getDestType()
3627  : packOrUnPack.getSourceType();
3628  size_t packedRank = packedType.getRank();
3629  // Require output rank to match input rank + number of blocking factors.
3630  if (unpackedRank + mixedTiles.size() != packedRank) {
3631  return op->emitError(
3632  "packed rank must equal unpacked rank + tiling factors");
3633  }
3634 
3635  // Verify that the result shape is at least as large as the minimum shape
3636  // expected by the pack operation, and that the output shape
3637  // represents full tiles.
3638  RankedTensorType expectedPackedType = PackOp::inferPackedType(
3639  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3640  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3641  return op->emitError("the shape of output is not large enough to hold the "
3642  "packed data. Expected at least ")
3643  << expectedPackedType << ", got " << packedType;
3644  }
3645  if (!llvm::all_of(
3646  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3647  mixedTiles),
3648  [](std::tuple<int64_t, OpFoldResult> it) {
3649  std::optional<int64_t> constTileSize =
3650  getConstantIntValue(std::get<1>(it));
3651  int64_t shape = std::get<0>(it);
3652  if (!constTileSize) {
3653  // If specified tile size is dynamic, output shape should
3654  // be dynamic too.
3655  return ShapedType::isDynamic(shape);
3656  }
3657  if (ShapedType::isDynamic(shape)) {
3658  // For the shape being dynamic when tile size is
3659  // specified, return true. In canonical form a constant
3660  // tile size should lead to constant shape of the tiled
3661  // dimension, but not needed for verification.
3662  return true;
3663  }
3664  return shape == constTileSize.value();
3665  })) {
3666  return op->emitError("mismatch in inner tile sizes specified and shape of "
3667  "tiled dimension in the packed type");
3668  }
3669  return success();
3670 }
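
// Illustrative example of a pack satisfying the checks above (hypothetical IR,
// not part of the original file): the packed rank (4) equals the unpacked rank
// (2) plus the number of tiling factors (2), and each trailing dimension of
// the packed type matches its inner tile size.
//
// ```mlir
// %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//     into %dest : tensor<32x16xf32> -> tensor<4x8x8x2xf32>
// ```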
3671 
3672 namespace {
3673 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
3674 /// various permutations to the op.
3675 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
3676 // these. These may or may not become true foldings / canonicalizations
3677 // depending on how aggressive we want to be in automatically folding
3678 // transposes.
3679 struct PackOrUnPackTransposeResult {
3680  SmallVector<int64_t> innerDimsPos;
3681  SmallVector<OpFoldResult> innerTiles;
3682  SmallVector<int64_t> outerDimsPerm;
3683 };
3684 } // namespace
3685 
3686 template <typename OpTy>
3687 static PackOrUnPackTransposeResult
3688 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
3689  ArrayRef<int64_t> innerPermutation,
3690  ArrayRef<int64_t> outerPermutation) {
3691  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3692  "applies to only pack or unpack operations");
3693  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3694  "some permutation must be non-empty");
3695  PackOrUnPackTransposeResult metadata;
3696  metadata.innerDimsPos =
3697  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
3698  metadata.innerTiles =
3699  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
3700  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3701  ? packOrUnPackOp.getSourceRank()
3702  : packOrUnPackOp.getDestRank();
3703  metadata.outerDimsPerm =
3704  packOrUnPackOp.getOuterDimsPerm().empty()
3705  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3706  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
3707  if (!innerPermutation.empty()) {
3708  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3709  isPermutationVector(innerPermutation) &&
3710  "invalid inner permutation");
3711  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
3712  applyPermutationToVector(metadata.innerTiles, innerPermutation);
3713  }
3714  if (!outerPermutation.empty()) {
3715  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3716  isPermutationVector(outerPermutation) &&
3717  "invalid outer permutation");
3718  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
3719  }
3720  return metadata;
3721 }
3722 
3723 //===----------------------------------------------------------------------===//
3724 // PackOp
3725 //===----------------------------------------------------------------------===//
3726 
3727 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3728  setNameFn(getResult(), "pack");
3729 }
3730 
3731 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
3732  Value dest, ArrayRef<int64_t> innerDimsPos,
3733  ArrayRef<OpFoldResult> innerTiles,
3734  std::optional<Value> paddingValue,
3735  ArrayRef<int64_t> outerDimsPerm) {
3736  assert(innerDimsPos.size() == innerTiles.size() &&
3737  "number of tile sizes specified must match the specified number of "
3738  "original dimensions to be tiled");
3739  SmallVector<int64_t> staticTileSizes;
3740  SmallVector<Value> dynamicTileSizes;
3741  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
3742  build(builder, state, dest.getType(), source, dest,
3743  paddingValue ? *paddingValue : nullptr,
3744  outerDimsPerm.empty() ? nullptr
3745  : builder.getDenseI64ArrayAttr(outerDimsPerm),
3746  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
3747  builder.getDenseI64ArrayAttr(staticTileSizes));
3748 }
3749 
3750 LogicalResult
3751 PackOp::reifyResultShapes(OpBuilder &builder,
3752  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3753  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
3754 }
3755 
3756 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
3757  return getDimAndTileMappingImpl(*this);
3758 }
3759 
3760 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
3761  return getMixedTilesImpl(*this);
3762 }
3763 
3764 SmallVector<int64_t> PackOp::getStaticTiles() {
3765  return getStaticTilesImpl(*this);
3766 }
3767 
3768 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
3769  ArrayRef<int64_t> innerDimsPos,
3770  ArrayRef<int64_t> outputShape,
3771  ArrayRef<int64_t> outerDimsPerm,
3772  ArrayRef<OpFoldResult> innerTiles) {
3773  SmallVector<int64_t> outputTileSizes(
3774  outputShape.take_front(inputShape.size()));
3775  if (!outerDimsPerm.empty()) {
3776  assert(outerDimsPerm.size() == outputTileSizes.size() &&
3777  "expected output and outer_dims_perm to have same size");
3778  applyPermutationToVector(outputTileSizes,
3779  invertPermutationVector(outerDimsPerm));
3780  }
3781  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3782  if (ShapedType::isDynamic(inputShape[pos]))
3783  continue;
3784  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3785 
3786  if (!constantTile) {
3787  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3788  (inputShape[pos] % outputTileSizes[pos] != 0))
3789  return true;
3790  } else if (inputShape[pos] % (*constantTile) != 0) {
3791  return true;
3792  }
3793  }
3794  return false;
3795 }
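
// Illustrative example (hypothetical IR, not part of the original file): a
// padding value is required below because the source dimension 13 is not
// divisible by the inner tile size 8 (nor 15 by 2); with inner_tiles = [13, 5]
// both dimensions would divide evenly and no padding value would be needed.
//
// ```mlir
// %pack = tensor.pack %src padding_value(%cst : f32)
//     inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//     into %dest : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
// ```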
3796 
3797 LogicalResult PackOp::verify() {
3798  if (failed(commonVerifierPackAndUnPackOp(*this)))
3799  return failure();
3800 
3801  // Verify padding value, and bail out if the tile does not divide the
3802  // dimension fully. In the case of dynamic tile factors or dimensions, having
3803  // a partial tile is undefined behavior.
3804  auto paddingValue = getPaddingValue();
3805  if (paddingValue &&
3806  paddingValue.getType() != getSourceType().getElementType()) {
3807  return emitOpError("expected padding_value has ")
3808  << getSourceType().getElementType()
3809  << " but got: " << paddingValue.getType();
3810  }
3811 
3812  if (!paddingValue &&
3813  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3814  getDestType().getShape(), getOuterDimsPerm(),
3815  getMixedTiles())) {
3816  return emitOpError(
3817  "invalid tile factor or output size provided. Only full tiles are "
3818  "supported when padding_value is not set");
3819  }
3820  return success();
3821 }
3822 
3823 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
3824 /// Value's to kDynamic, even if they are arith.constant values.
3825 static SmallVector<int64_t>
3826 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
3827  SmallVector<int64_t> result;
3828  for (auto o : ofrs) {
3829  // Have to do this first, as getConstantIntValue special-cases constants.
3830  if (llvm::dyn_cast_if_present<Value>(o))
3831  result.push_back(ShapedType::kDynamic);
3832  else
3833  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
3834  }
3835  return result;
3836 }
3837 
3838 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
3839 /// the packed type. Having a shared helper helps implement these two methods in
3840 /// a way that ensures that they agree on which dimensions are dynamic.
3841 static SmallVector<int64_t> getPackOpResultTypeShape(
3842  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
3843  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
3844  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
3845  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3846  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3847  continue;
3848  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3849  resultShape[tiledDim.value()] = ShapedType::kDynamic;
3850  continue;
3851  }
3852  resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
3853  innerTileSizes[tiledDim.index()]);
3854  }
3855 
3856  // Swap tile loops if outer_dims_perm is available.
3857  if (!outerDimsPerm.empty())
3858  applyPermutationToVector(resultShape, outerDimsPerm);
3859 
3860  // Append the inner tile dimensions.
3861  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
3862  return resultShape;
3863 }
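
// Worked example for the helper above (hypothetical shapes): for a source
// shape [31, ?], inner tile sizes [8, 2], inner_dims_pos = [0, 1], and an
// empty outer_dims_perm, the tiled outer dimensions become
// [ceilDiv(31, 8), kDynamic] = [4, ?]; appending the inner tiles yields the
// packed shape [4, ?, 8, 2].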
3864 
3865 SmallVector<OpFoldResult> PackOp::getResultShape(
3866  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
3867  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
3868  ArrayRef<int64_t> outerDimsPerm) {
3869  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
3870 
3871  AffineExpr s0, s1;
3872  bindSymbols(builder.getContext(), s0, s1);
3873  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
3874  for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
3875  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
3876  builder, loc, ceilDivExpr,
3877  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
3878  }
3879  if (!outerDimsPerm.empty())
3880  applyPermutationToVector(resultDims, outerDimsPerm);
3881  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
3882 
3883  SmallVector<int64_t> resultTypeShape =
3884  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
3885  asShapeWithAnyValueAsDynamic(innerTileSizes),
3886  innerDimsPos, outerDimsPerm);
3887 
3888  // Fix-up `resultDims` to ensure that they are Value's if and only if the
3889  // result type shape says it's a dynamic dim. This is needed as callers may
3890  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
3891  // dynamic dims returned by that.
3892  for (unsigned i = 0; i < resultDims.size(); ++i) {
3893  if (!ShapedType::isDynamic(resultTypeShape[i]))
3894  continue;
3895  resultDims[i] =
3896  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
3897  }
3898 
3899  return resultDims;
3900 }
3901 
3902 /// Get the expected packed type based on source type, tile factors, position of
3903 /// the inner tiles and permutation of the outer tiled loop.
3904 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
3905  ArrayRef<int64_t> innerTileSizes,
3906  ArrayRef<int64_t> innerDimsPos,
3907  ArrayRef<int64_t> outerDimsPerm) {
3908  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
3909  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
3910  return RankedTensorType::get(resultShape, sourceType.getElementType());
3911 }
3912 
3913 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
3914  ArrayRef<OpFoldResult> innerTileSizes,
3915  ArrayRef<int64_t> innerDimsPos,
3916  ArrayRef<int64_t> outerDimsPerm) {
3917  AffineExpr dim0, dim1;
3918  bindDims(b.getContext(), dim0, dim1);
3919  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
3920  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
3921  {v1, v2});
3922  };
3923 
3924  SmallVector<OpFoldResult> mixedSizes;
3925  for (auto [index, value] : llvm::enumerate(
3926  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
3927  if (ShapedType::isDynamic(value))
3928  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
3929  else
3930  mixedSizes.push_back(b.getIndexAttr(value));
3931  }
3932  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
3933  int64_t dimPos = std::get<0>(it);
3934  OpFoldResult tileSize = std::get<1>(it);
3935  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
3936  }
3937  if (!outerDimsPerm.empty())
3938  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
3939 
3940  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
3941  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3942  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
3943 }
3944 
3945 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
3946  ArrayRef<int64_t> innerPermutation,
3947  ArrayRef<int64_t> outerPermutation) {
3948  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
3949  *this, innerPermutation, outerPermutation);
3950  Value transposedDest =
3951  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
3952  metadata.innerDimsPos, metadata.outerDimsPerm);
3953  return b.create<PackOp>(loc, getSource(), transposedDest,
3954  metadata.innerDimsPos, metadata.innerTiles,
3955  getPaddingValue(), metadata.outerDimsPerm);
3956 }
3957 
3958 /// Returns true if the tiles and the tiled dims are constant.
3959 template <typename OpTy>
3960 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
3961  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3962  "applies to only pack or unpack operations");
3963  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3964  ? op.getDestType()
3965  : op.getSourceType();
3966  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
3967  for (auto [dimDest, tile] : llvm::zip(
3968  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
3969  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
3970  if (!constTileSize || ShapedType::isDynamic(dimDest))
3971  return false;
3972  }
3973  return true;
3974 }
3975 
3976 Speculation::Speculatability PackOp::getSpeculatability() {
3977  if (getPaddingValue())
3978  return Speculation::Speculatable;
3979 
3980  // The verifier already rejects operations where we can statically prove that
3981  // the tile sizes do not divide the dimension evenly; thus, only check that
3982  // the tiles and the tiled inner dimensions are constant.
3983  if (!areTilesAndTiledDimsAllConstant(*this))
3984  return Speculation::NotSpeculatable;
3985 
3986  return Speculation::Speculatable;
3987 }
3988 
3989 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
3990 // dimensions for pack and unpack.
3991 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
3992  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3993  return false;
3994  return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
3995 }
3996 
3997 // Return true if pack and unpack have the same tiles.
3998 // Same SSA values or same integer constants.
3999 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4000  auto packTiles = packOp.getMixedTiles();
4001  auto unPackTiles = unPackOp.getMixedTiles();
4002  if (packTiles.size() != unPackTiles.size())
4003  return false;
4004  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4005  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4006  return false;
4007  }
4008  return true;
4009 }
4010 
4011 /// Returns true if the pack op does not need a padding value.
4012 static bool paddingIsNotNeeded(PackOp op) {
4013  auto srcType = op.getSourceType();
4014  if (llvm::any_of(op.getInnerDimsPos(),
4015  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4016  return false;
4017  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4018  return false;
4019  return !PackOp::requirePaddingValue(
4020  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4021  op.getOuterDimsPerm(), op.getMixedTiles());
4022 }
4023 
4024 /// Returns true if the `srcShape` or `destShape` is different from the one in
4025 /// `packOp` and populates each with the inferred static shape.
4026 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4027  SmallVectorImpl<int64_t> &destShape) {
4028  bool changeNeeded = false;
4029  srcShape.assign(packOp.getSourceType().getShape().begin(),
4030  packOp.getSourceType().getShape().end());
4031  destShape.assign(packOp.getDestType().getShape().begin(),
4032  packOp.getDestType().getShape().end());
4033  llvm::SmallSetVector<int64_t, 4> innerDims;
4034  innerDims.insert(packOp.getInnerDimsPos().begin(),
4035  packOp.getInnerDimsPos().end());
4036  SmallVector<int64_t> inverseOuterDimsPerm;
4037  if (!packOp.getOuterDimsPerm().empty())
4038  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4039  int srcRank = packOp.getSourceRank();
4040  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4041  if (innerDims.contains(i))
4042  continue;
4043  int64_t srcPos = i;
4044  int64_t destPos = i;
4045  if (!inverseOuterDimsPerm.empty())
4046  destPos = inverseOuterDimsPerm[srcPos];
4047  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4048  ShapedType::isDynamic(destShape[destPos])) {
4049  continue;
4050  }
4051  int64_t size = srcShape[srcPos];
4052  if (ShapedType::isDynamic(size))
4053  size = destShape[destPos];
4054  srcShape[srcPos] = size;
4055  destShape[destPos] = size;
4056  changeNeeded = true;
4057  }
4058  return changeNeeded;
4059 }
4060 
4061 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4062  // Fold pack(unpack(x)) to x.
4063  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4064  if (unPackOp.getSourceType() != packOp.getDestType())
4065  return failure();
4066  if (packOp.getPaddingValue() ||
4067  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4068  !haveSameTiles(packOp, unPackOp))
4069  return failure();
4070  rewriter.replaceOp(packOp, unPackOp.getSource());
4071  return success();
4072  }
4073 
4074  // Fold optional PaddingValue operand away if padding is not needed.
4075  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4076  rewriter.startOpModification(packOp);
4077  packOp.getPaddingValueMutable().clear();
4078  rewriter.finalizeOpModification(packOp);
4079  return success();
4080  }
4081 
4082  // Insert tensor.cast ops if static shape inference is available.
4083  SmallVector<int64_t> srcShape, destShape;
4084  if (inferStaticShape(packOp, srcShape, destShape)) {
4085  Location loc = packOp.getLoc();
4086  Value source = packOp.getSource();
4087  if (srcShape != packOp.getSourceType().getShape()) {
4088  auto newSrcType = packOp.getSourceType().clone(srcShape);
4089  source =
4090  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4091  }
4092  Value dest = packOp.getDest();
4093  if (destShape != packOp.getDestType().getShape()) {
4094  auto newDestType = packOp.getDestType().clone(destShape);
4095  dest =
4096  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4097  }
4098  Value newOp = rewriter.create<tensor::PackOp>(
4099  loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4100  packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4101  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4102  packOp, packOp.getResult().getType(), newOp);
4103  return success();
4104  }
4105 
4106  return failure();
4107 }
4108 
4109 template <typename PackOrUnpackOp>
4110 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4111  RankedTensorType packedTensorType) {
4112  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4113  std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4114  "Function meant for pack/unpack");
4115  // This is a pad if packing only adds ones and we don't transpose dimensions.
4116 
4117  // Check that we are not transposing any dimensions.
4118  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4119  int64_t numPackedDims = innerDimsPos.size();
4120  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4121  if (orderedDims != innerDimsPos) {
4122  // Dimensions don't happen in order.
4123  return false;
4124  }
4125 
4126  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4127  int64_t packedRank = packedTensorType.getRank();
4128  // At this point we know that we are taking numPackedDims outer
4129  // dimensions and pushing them all the way in as the innermost dimensions.
4130  // What's left on the outermost dimensions is, in this order:
4131  // - the factors of the packed dimensions, then
4132  // - the untouched dimensions.
4133  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4134  // if all the dimensions that bubble outward are ones.
4135  // Therefore check that all the dimensions except the numPackedDims innermost
4136  // ones are ones.
4137  return llvm::all_of(
4138  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4139  [&packedShape](int64_t i) { return packedShape[i] == 1; });
4140 }
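
// Illustrative example (hypothetical IR, not part of the original file): the
// pack below only pads, since its single outer dimension is 1 and no
// dimensions are transposed; it behaves like a pad of tensor<13xf32> to 16
// elements.
//
// ```mlir
// %0 = tensor.pack %src padding_value(%cst : f32)
//     inner_dims_pos = [0] inner_tiles = [16]
//     into %dest : tensor<13xf32> -> tensor<1x16xf32>
// ```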
4141 
4142 bool PackOp::isLikePad() {
4143  auto packedTensorType =
4144  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4145  return isLikePadUnPad(*this, packedTensorType);
4146 }
4147 
4148 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4149  std::optional<Attribute> paddingValue;
4150  if (auto pad = adaptor.getPaddingValue())
4151  paddingValue = pad;
4152  if (OpFoldResult reshapedSource = reshapeConstantSource(
4153  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4154  getDestType(), paddingValue))
4155  return reshapedSource;
4156  return {};
4157 }
4158 
4159 //===----------------------------------------------------------------------===//
4160 // UnPackOp
4161 //===----------------------------------------------------------------------===//
4162 
4163 void UnPackOp::getAsmResultNames(
4164  function_ref<void(Value, StringRef)> setNameFn) {
4165  setNameFn(getResult(), "unpack");
4166 }
4167 
4168 LogicalResult
4169 UnPackOp::reifyResultShapes(OpBuilder &builder,
4170  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4171  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4172 }
4173 
4174 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4175  return getDimAndTileMappingImpl(*this);
4176 }
4177 
4178 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4179  return getMixedTilesImpl(*this);
4180 }
4181 
4182 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4183  return getStaticTilesImpl(*this);
4184 }
4185 
4186 LogicalResult UnPackOp::verify() {
4187  return commonVerifierPackAndUnPackOp(*this);
4188 }
4189 
4190 Speculation::Speculatability UnPackOp::getSpeculatability() {
4191  // See PackOp::getSpeculatability.
4192  if (!areTilesAndTiledDimsAllConstant(*this))
4193  return Speculation::NotSpeculatable;
4194 
4195  return Speculation::Speculatable;
4196 }
4197 
4198 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4199  Value dest, ArrayRef<int64_t> innerDimsPos,
4200  ArrayRef<OpFoldResult> innerTiles,
4201  ArrayRef<int64_t> outerDimsPerm) {
4202  assert(innerDimsPos.size() == innerTiles.size() &&
4203  "number of tile sizes specified must match the specified number of "
4204  "original dimensions to be tiled");
4205  SmallVector<int64_t> staticTileSizes;
4206  SmallVector<Value> dynamicTileSizes;
4207  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4208  build(builder, state, dest.getType(), source, dest,
4209  outerDimsPerm.empty() ? nullptr
4210  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4211  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4212  builder.getDenseI64ArrayAttr(staticTileSizes));
4213 }
4214 
4215 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4216  Value source,
4217  ArrayRef<OpFoldResult> innerTileSizes,
4218  ArrayRef<int64_t> innerDimsPos,
4219  ArrayRef<int64_t> outerDimsPerm) {
4220  AffineExpr sym0, sym1;
4221  bindSymbols(b.getContext(), sym0, sym1);
4222  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4223  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4224  };
4225 
4226  SmallVector<OpFoldResult> mixedSizes;
4227  auto srcType = llvm::cast<RankedTensorType>(source.getType());
4228  for (auto i :
4229  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4230  if (srcType.isDynamicDim(i))
4231  mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4232  else
4233  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4234  }
4235  if (!outerDimsPerm.empty()) {
4236  applyPermutationToVector<OpFoldResult>(
4237  mixedSizes, invertPermutationVector(outerDimsPerm));
4238  }
4239 
4240  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4241  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4242 
4243  auto elemType = srcType.getElementType();
4244  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4245 }
4246 
4247 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4248  Value transposedSource,
4249  ArrayRef<int64_t> innerPermutation,
4250  ArrayRef<int64_t> outerPermutation) {
4251  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4252  *this, innerPermutation, outerPermutation);
4253  return b.create<UnPackOp>(loc, transposedSource, getDest(),
4254  metadata.innerDimsPos, metadata.innerTiles,
4255  metadata.outerDimsPerm);
4256 }
4257 
4258 /// Returns true if the `srcShape` or `destShape` is different from the one in
4259 /// `op` and populates each with the inferred static shape.
4260 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4261  SmallVectorImpl<int64_t> &destShape) {
4262  bool changeNeeded = false;
4263  srcShape.assign(op.getSourceType().getShape().begin(),
4264  op.getSourceType().getShape().end());
4265  destShape.assign(op.getDestType().getShape().begin(),
4266  op.getDestType().getShape().end());
4267  llvm::SmallSetVector<int64_t, 4> innerDims;
4268  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4269  SmallVector<int64_t> inverseOuterDimsPerm;
4270  if (!op.getOuterDimsPerm().empty())
4271  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4272  int destRank = op.getDestRank();
4273  for (auto i : llvm::seq<int64_t>(0, destRank)) {
4274  if (innerDims.contains(i))
4275  continue;
4276  int64_t srcPos = i;
4277  int64_t destPos = i;
4278  if (!inverseOuterDimsPerm.empty())
4279  srcPos = inverseOuterDimsPerm[destPos];
4280  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4281  ShapedType::isDynamic(destShape[destPos])) {
4282  continue;
4283  }
4284  int64_t size = srcShape[srcPos];
4285  if (ShapedType::isDynamic(size))
4286  size = destShape[destPos];
4287  srcShape[srcPos] = size;
4288  destShape[destPos] = size;
4289  changeNeeded = true;
4290  }
4291  return changeNeeded;
4292 }
4293 
4294 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4295  PatternRewriter &rewriter) {
4296  /// unpack(pack(x)) -> x
4297  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4298  if (packOp.getDestType() != unPackOp.getSourceType())
4299  return failure();
4300  if (packOp.getPaddingValue() ||
4301  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4302  !haveSameTiles(packOp, unPackOp))
4303  return failure();
4304  rewriter.replaceOp(unPackOp, packOp.getSource());
4305  return success();
4306  }
4307  /// unpack(destinationStyleOp(x)) -> unpack(x)
4308  if (auto dstStyleOp =
4309  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4310  auto destValue = unPackOp.getDest().cast<OpResult>();
4311  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4312  rewriter.modifyOpInPlace(unPackOp,
4313  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4314  return success();
4315  }
4316 
4317  // Insert tensor.cast ops if static shape inference is available.
4318  SmallVector<int64_t> srcShape, destShape;
4319  if (inferStaticShape(unPackOp, srcShape, destShape)) {
4320  Location loc = unPackOp.getLoc();
4321  Value source = unPackOp.getSource();
4322  if (srcShape != unPackOp.getSourceType().getShape()) {
4323  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4324  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4325  unPackOp.getSource());
4326  }
4327  Value dest = unPackOp.getDest();
4328  if (destShape != unPackOp.getDestType().getShape()) {
4329  auto newDestType = unPackOp.getDestType().clone(destShape);
4330  dest =
4331  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4332  }
4333  Value newOp = rewriter.create<tensor::UnPackOp>(
4334  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4335  unPackOp.getOuterDimsPerm());
4336  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4337  unPackOp, unPackOp.getResult().getType(), newOp);
4338  return success();
4339  }
4340 
4341  return failure();
4342 }
4343 
4344 bool UnPackOp::isLikeUnPad() {
4345  RankedTensorType packedTensorType = getSourceType();
4346  return isLikePadUnPad(*this, packedTensorType);
4347 }
4348 
4349 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4350  if (OpFoldResult reshapedSource = reshapeConstantSource(
4351  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4352  getResult().getType()))
4353  return reshapedSource;
4354  return {};
4355 }
4356 
4357 //===----------------------------------------------------------------------===//
4358 // Common Canonicalizers and Folders.
4359 //===----------------------------------------------------------------------===//
4360 
4361 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4362 /// the `tensor.cast` has a source that is more static than the consuming op.
4363 ///
4364 /// Example:
4365 /// ```mlir
4366 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4367 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4368 /// ```
4369 ///
4370 /// folds into:
4371 ///
4372 /// ```mlir
4373 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4374 /// ```
4375 /// TODO: Move this pattern to a proper place so that all other
4376 /// DestinationStyleOp implementations can add it to their canonicalizers.
4377 struct FoldTensorCastProducerOp
4378  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4379  using OpInterfaceRewritePattern<
4380  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4381 
4382  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4383  PatternRewriter &rewriter) const override {
4384  // InsertSliceOp has its own logic about folding tensor.cast ops.
4385  if (isa<InsertSliceOp>(op.getOperation()))
4386  return failure();
4387 
4388  // Exclude DPS ops that are also LoopLike from this interface as they
4389  // might need special handling of attached regions.
4390  if (isa<LoopLikeOpInterface>(op.getOperation()))
4391  return failure();
4392 
4393  // If no operand comes from a tensor::CastOp that can be folded, then fail.
4394  bool hasTensorCastOperand =
4395  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4396  if (llvm::isa<BlockArgument>(opOperand.get()))
4397  return false;
4398  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4399  return castOp && canFoldIntoConsumerOp(castOp);
4400  });
4401  if (!hasTensorCastOperand)
4402  return failure();
4403 
4404  SmallVector<Type, 4> newResultTypes;
4405  newResultTypes.reserve(op->getNumResults());
4406  SmallVector<Value, 4> newOperands;
4407  newOperands.reserve(op->getNumOperands());
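  // For each operand, use the cast's source when the cast can be folded away;
  // tensor-typed DPS init operands also determine the result types of the
  // cloned op.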
4408  for (OpOperand &opOperand : op->getOpOperands()) {
4409  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4410  bool fold = canFoldIntoConsumerOp(tensorCastOp);
4411  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4412  if (op.isDpsInit(&opOperand) &&
4413  !llvm::isa<MemRefType>(newOperands.back().getType()))
4414  newResultTypes.push_back(newOperands.back().getType());
4415  }
4416 
4417  // Clone op.
4418  Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
4419  SmallVector<Value, 4> replacements;
4420  replacements.reserve(newOp->getNumResults());
4421  for (auto [oldResult, newResult] :
4422  llvm::zip(op->getResults(), newOp->getResults())) {
4423  if (newResult.getType() != oldResult.getType()) {
4424  replacements.push_back(rewriter.create<tensor::CastOp>(
4425  op->getLoc(), oldResult.getType(), newResult));
4426  } else {
4427  replacements.push_back(newResult);
4428  }
4429  }
4430  rewriter.replaceOp(op, replacements);
4431 
4432  return success();
4433  }
4434 };
4435 
4436 //===----------------------------------------------------------------------===//
4437 // TensorDialect
4438 //===----------------------------------------------------------------------===//
4439 
4440 void TensorDialect::getCanonicalizationPatterns(
4441  RewritePatternSet &results) const {
4442  results.add<FoldTensorCastProducerOp>(getContext());
4443 }
4444 
4445 //===----------------------------------------------------------------------===//
4446 // TableGen'd op method definitions
4447 //===----------------------------------------------------------------------===//
4448 
4449 #define GET_OP_CLASSES
4450 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: TensorOps.cpp:3960
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
Definition: TensorOps.cpp:383
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
Definition: TensorOps.cpp:1288
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: TensorOps.cpp:3688
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
Definition: TensorOps.cpp:2053
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: TensorOps.cpp:3508
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: TensorOps.cpp:3539
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: TensorOps.cpp:4012
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2754
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: TensorOps.cpp:3841
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
Definition: TensorOps.cpp:3826
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: TensorOps.cpp:3523
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
Definition: TensorOps.cpp:2474
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
Definition: TensorOps.cpp:2326
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
Definition: TensorOps.cpp:2348
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
Definition: TensorOps.cpp:4026
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if each dimension of sourceShape is no larger than the corresponding dimension of limitShape.
Definition: TensorOps.cpp:3569
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1541
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:3999
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
Definition: TensorOps.cpp:2435
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: TensorOps.cpp:4110
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: TensorOps.cpp:3584
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
Definition: TensorOps.cpp:175
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: TensorOps.cpp:3496
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: TensorOps.cpp:3552
static OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result, std::optional< Attribute > cst=std::nullopt)
Try to remove a tensor operation if it would only reshape a constant.
Definition: TensorOps.cpp:1074
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
Definition: TensorOps.cpp:2750
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
Definition: TensorOps.cpp:131
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: TensorOps.cpp:3991
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
Definition: TensorOps.cpp:2494
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
Definition: TensorOps.cpp:1675
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:69
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:926
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:148
unsigned getNumArguments()
Definition: Block.h:125
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at t...
Definition: Block.h:206
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:375
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:249
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:260
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:91
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1188
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
MPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:64
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:334
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:354
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2286
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:345
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:315
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Definition: TensorOps.cpp:2727
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
Definition: TensorOps.cpp:2373
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:119
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:51
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:70
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:154
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:263
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:105
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:369
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:363
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:954
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:29
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
Definition: TensorOps.cpp:4378
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
Definition: TensorOps.cpp:4382
A canonicalizer wrapper to replace ExtractSliceOps.
Definition: TensorOps.cpp:2305
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Definition: TensorOps.cpp:2306
Return the canonical type of the result of an extract_slice op.
Definition: TensorOps.cpp:2293
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Definition: TensorOps.cpp:2294
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)