TensorOps.cpp (MLIR 20.0.0git)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Complex/IR/Complex.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/IndexingUtils.h"
15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinAttributeInterfaces.h"
19 #include "mlir/IR/BuiltinTypeInterfaces.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpDefinition.h"
24 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/MathExtras.h"
33 #include <algorithm>
34 #include <optional>
35 
36 using namespace mlir;
37 using namespace mlir::tensor;
38 
39 using llvm::divideCeilSigned;
40 using llvm::divideFloorSigned;
41 using llvm::mod;
42 
43 /// Materialize a single constant operation from a given attribute value with
44 /// the desired resultant type.
45 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
46  Attribute value, Type type,
47  Location loc) {
48  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
49  return op;
50  if (complex::ConstantOp::isBuildableWith(value, type))
51  return builder.create<complex::ConstantOp>(loc, type,
52  llvm::cast<ArrayAttr>(value));
53  return nullptr;
54 }
55 
56 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
57  int64_t dim) {
58  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
60  if (tensorType.isDynamicDim(dim))
61  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
62 
63  return builder.getIndexAttr(tensorType.getDimSize(dim));
64 }
65 
66 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
67  Location loc, Value value) {
68  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
69  SmallVector<OpFoldResult> result;
70  for (int64_t i = 0; i < tensorType.getRank(); ++i)
71  result.push_back(getMixedSize(builder, loc, value, i));
72  return result;
73 }
74 
75 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
76  OpResult opResult) {
77  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
78  assert(tensorType && "expected tensor type");
79 
80  // If the op has a destination, it implements DestinationStyleOpInterface and
81  // we can query the destination operand from that interface.
82  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
83  if (destOp)
84  return destOp.getTiedOpOperand(opResult)->get();
85 
86  // Otherwise, create a new destination tensor with the same shape.
87  OpBuilder::InsertionGuard g(b);
88  b.setInsertionPoint(opResult.getDefiningOp());
89 
90  // Compute sizes.
91  SmallVector<OpFoldResult> mixedSizes;
92  if (!tensorType.hasStaticShape()) {
93  // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
94  ReifiedRankedShapedTypeDims reifiedShapes;
95  if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
96  return failure();
97  mixedSizes = reifiedShapes[opResult.getResultNumber()];
98  } else {
99  // Static shape: Take static sizes directly.
100  for (int64_t sz : tensorType.getShape())
101  mixedSizes.push_back(b.getIndexAttr(sz));
102  }
103 
104  // Create empty tensor.
105  Value emptyTensor =
106  b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
107  return emptyTensor;
108 }
109 
110 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
111  Operation *op,
112  SmallVector<Value> &result) {
113  for (OpResult opResult : op->getResults()) {
114  if (llvm::isa<TensorType>(opResult.getType())) {
115  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
116  if (failed(destination))
117  return failure();
118  result.push_back(*destination);
119  }
120  }
121  return success();
122 }
123 
124 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
125  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
126  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
127  return rtp1.getShape() == rtp2.getShape() &&
128  rtp1.getElementType() == rtp2.getElementType();
129  return false;
130  }
131  return tp1 == tp2; // default implementation
132 }
133 
134 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
135 /// rank-extending tensor.insert_slice op.
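/// For illustration (hypothetical values), a rank-reducing extract_slice
/// where the static unit sizes that do not survive into the reduced shape
/// are the dropped dimensions:
///
/// ```mlir
/// // reducedShape = [?], mixedSizes = [1, %sz, 1]
/// %0 = tensor.extract_slice %t[0, 0, 0] [1, %sz, 1] [1, 1, 1]
///     : tensor<1x?x1xf32> to tensor<?xf32>
/// // droppedDims = {0, 2}
/// ```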
136 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
137  ArrayRef<OpFoldResult> mixedSizes) {
138  llvm::SmallBitVector droppedDims(mixedSizes.size());
139  int64_t shapePos = reducedShape.size() - 1;
140 
141  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
142  size_t idx = mixedSizes.size() - size.index() - 1;
143  // Rank-reduced dims must have a static unit dimension.
144  bool isStaticUnitSize =
145  size.value().is<Attribute>() &&
146  llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
147 
148  if (shapePos < 0) {
149  // There are no more dims in the reduced shape. All remaining sizes must
150  // be rank-reduced dims.
151  assert(isStaticUnitSize && "expected unit dim");
152  droppedDims.set(idx);
153  continue;
154  }
155 
156  // Dim is preserved if the size is not a static 1.
157  if (!isStaticUnitSize) {
158  --shapePos;
159  continue;
160  }
161 
162  // Dim is preserved if the reduced shape dim is also 1.
163  if (reducedShape[shapePos] == 1) {
164  --shapePos;
165  continue;
166  }
167 
168  // Otherwise: Dim is dropped.
169  droppedDims.set(idx);
170  }
171 
172  assert(shapePos < 0 && "dimension mismatch");
173  return droppedDims;
174 }
175 
176 /// Given a ranked tensor type and a range of values that defines its dynamic
177 /// dimension sizes, turn all dynamic sizes that have a constant value into
178 /// static dimension sizes.
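/// A sketch of the folding (hypothetical values):
///
/// ```mlir
/// %c4 = arith.constant 4 : index
/// // type = tensor<?x?xf32>, dynamicSizes = [%c4, %n]
/// // folded type = tensor<4x?xf32>, foldedDynamicSizes = [%n]
/// ```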
179 static RankedTensorType
180 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
181  SmallVector<Value> &foldedDynamicSizes) {
182  SmallVector<int64_t> staticShape(type.getShape());
183  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
184  "incorrect number of dynamic sizes");
185 
186  // Compute new static and dynamic sizes.
187  unsigned ctr = 0;
188  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
189  if (type.isDynamicDim(i)) {
190  Value dynamicSize = dynamicSizes[ctr++];
191  std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
192  if (cst.has_value()) {
193  // Dynamic size must be non-negative.
194  if (cst.value() < 0) {
195  foldedDynamicSizes.push_back(dynamicSize);
196  continue;
197  }
198  staticShape[i] = *cst;
199  } else {
200  foldedDynamicSizes.push_back(dynamicSize);
201  }
202  }
203  }
204 
205  return RankedTensorType::get(staticShape, type.getElementType(),
206  type.getEncoding());
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // BitcastOp
211 //===----------------------------------------------------------------------===//
212 
213 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
214  if (inputs.size() != 1 || outputs.size() != 1)
215  return false;
216  Type a = inputs.front(), b = outputs.front();
217  auto aT = dyn_cast<TensorType>(a);
218  auto bT = dyn_cast<TensorType>(b);
219  if (!aT || !bT)
220  return false;
221 
222  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
223  return false;
224 
225  return succeeded(verifyCompatibleShape(aT, bT));
226 }
227 
228 namespace {
229 
230 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
231 /// operation.
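/// For example (illustrative, hypothetical values):
///
/// ```mlir
/// %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xf32>
/// %2 = tensor.bitcast %1 : tensor<4xf32> to tensor<4xui32>
/// ```
///
/// folds to:
///
/// ```mlir
/// %2 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>
/// ```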
232 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
233  using OpRewritePattern<BitcastOp>::OpRewritePattern;
234 
235  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
236  PatternRewriter &rewriter) const final {
237  auto tensorBitcastOperand =
238  tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
239  if (!tensorBitcastOperand)
240  return failure();
241 
242  auto resultType = cast<TensorType>(tensorBitcast.getType());
243  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
244  tensorBitcastOperand.getOperand());
245  return success();
246  }
247 };
248 
249 } // namespace
250 
251 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
252  MLIRContext *context) {
253  results.add<ChainedTensorBitcast>(context);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // CastOp
258 //===----------------------------------------------------------------------===//
259 
260 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
261  setNameFn(getResult(), "cast");
262 }
263 
264 /// Returns true if `target` is a ranked tensor type that preserves static
265 /// information available in the `source` ranked tensor type.
266 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
267  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
268  auto targetType = llvm::dyn_cast<RankedTensorType>(target);
269 
270  // Requires RankedTensorType.
271  if (!sourceType || !targetType)
272  return false;
273 
274  // Requires same elemental type.
275  if (sourceType.getElementType() != targetType.getElementType())
276  return false;
277 
278  // Requires same rank.
279  if (sourceType.getRank() != targetType.getRank())
280  return false;
281 
282  // Requires same encoding.
283  if (sourceType.getEncoding() != targetType.getEncoding())
284  return false;
285 
286  // If cast is towards more static sizes along any dimension, don't fold.
287  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
288  if (!ShapedType::isDynamic(std::get<0>(t)) &&
289  ShapedType::isDynamic(std::get<1>(t)))
290  return false;
291  }
292 
293  return true;
294 }
295 
296 /// Determines whether tensor::CastOp casts to a more dynamic version of the
297 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
298 /// implement canonicalization patterns for ops in different dialects that may
299 /// consume the results of tensor.cast operations. Such foldable tensor.cast
300 /// operations are typically inserted as `slice` ops and are canonicalized
301 /// to preserve the type compatibility of their uses.
302 ///
303 /// Returns true when all conditions are met:
304 /// 1. source and result are ranked tensors with same element type and rank.
305 /// 2. the source type has more static information than the result.
306 ///
307 /// Example:
308 /// ```mlir
309 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
310 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
311 /// ```
312 ///
313 /// folds into:
314 ///
315 /// ```mlir
316 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
317 /// ```
318 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
319  if (!castOp)
320  return false;
321 
322  // Can fold if the source of cast has at least as much static information as
323  // its results.
324  return preservesStaticInformation(castOp.getType(),
325  castOp.getSource().getType());
326 }
327 
328 /// Determines whether the tensor::CastOp casts to a more static version of the
329 /// source tensor. This is useful to fold into a producing op and implement
330 /// canonicalization patterns with the `tensor.cast` op as the root, but the
331 /// producer being from different dialects. Returns true when all conditions are met:
332 /// 1. source and result are ranked tensors with same element type and rank.
333 /// 2. the result type has more static information than the source.
334 ///
335 /// Example:
336 /// ```mlir
337 /// %1 = producer ... : tensor<?x?xf32>
338 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
339 /// ```
340 ///
341 /// can be canonicalized to:
342 ///
343 /// ```mlir
344 /// %2 = producer ... : tensor<8x16xf32>
345 /// ```
346 /// Not all ops might be canonicalizable this way, but for those that can be,
347 /// this method provides a check that it is worth doing the canonicalization.
348 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
349  if (!castOp)
350  return false;
351  return preservesStaticInformation(castOp.getSource().getType(),
352  castOp.getType());
353 }
354 
355 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
356 /// that can be folded.
357 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
358  bool folded = false;
359  for (OpOperand &operand : op->getOpOperands()) {
360  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
361  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
362  operand.set(castOp.getOperand());
363  folded = true;
364  }
365  }
366  return success(folded);
367 }
368 
369 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
370  if (inputs.size() != 1 || outputs.size() != 1)
371  return false;
372  Type a = inputs.front(), b = outputs.front();
373  auto aT = llvm::dyn_cast<TensorType>(a);
374  auto bT = llvm::dyn_cast<TensorType>(b);
375  if (!aT || !bT)
376  return false;
377 
378  if (aT.getElementType() != bT.getElementType())
379  return false;
380 
381  return succeeded(verifyCompatibleShape(aT, bT));
382 }
383 
384 /// Compute a TensorType that has the joined shape knowledge of the two
385 /// given TensorTypes. The element types need to match.
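/// For illustration (hypothetical types): joining tensor<?x16xf32> and
/// tensor<8x?xf32> yields tensor<8x16xf32>; joining tensor<8x16xf32> and
/// tensor<4x16xf32> has no result because the static sizes conflict.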
386 static TensorType joinShapes(TensorType one, TensorType two) {
387  assert(one.getElementType() == two.getElementType());
388 
389  if (!one.hasRank())
390  return two;
391  if (!two.hasRank())
392  return one;
393 
394  int64_t rank = one.getRank();
395  if (rank != two.getRank())
396  return {};
397 
398  SmallVector<int64_t, 4> join;
399  join.reserve(rank);
400  for (int64_t i = 0; i < rank; ++i) {
401  if (one.isDynamicDim(i)) {
402  join.push_back(two.getDimSize(i));
403  continue;
404  }
405  if (two.isDynamicDim(i)) {
406  join.push_back(one.getDimSize(i));
407  continue;
408  }
409  if (one.getDimSize(i) != two.getDimSize(i))
410  return {};
411  join.push_back(one.getDimSize(i));
412  }
413  return RankedTensorType::get(join, one.getElementType());
414 }
415 
416 namespace {
417 
418 /// Replaces chains of two tensor.cast operations by a single tensor.cast
419 /// operation if doing so does not remove runtime constraints.
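/// As an illustration (hypothetical values), the chain
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// %2 = tensor.cast %1 : tensor<4x?xf32> to tensor<?x?xf32>
/// ```
///
/// is not collapsed, because dropping the intermediate cast would lose the
/// runtime check that dimension 0 is 4, whereas
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x4xf32>
/// ```
///
/// folds to a single cast from tensor<4x?xf32> to tensor<4x4xf32>.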
420 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
421  using OpRewritePattern<CastOp>::OpRewritePattern;
422 
423  LogicalResult matchAndRewrite(CastOp tensorCast,
424  PatternRewriter &rewriter) const final {
425  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
426 
427  if (!tensorCastOperand)
428  return failure();
429 
430  auto sourceType =
431  llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
432  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
433  auto resultType = llvm::cast<TensorType>(tensorCast.getType());
434 
435  // We can remove the intermediate cast if joining all three produces the
436  // same result as just joining the source and result shapes.
437  auto firstJoin =
438  joinShapes(joinShapes(sourceType, intermediateType), resultType);
439 
440  // The join might not exist if the cast sequence would fail at runtime.
441  if (!firstJoin)
442  return failure();
443 
444  // The newJoin always exists if the above join exists; it might just contain
445  // less information. If so, we cannot drop the intermediate cast, as doing
446  // so would remove runtime checks.
447  auto newJoin = joinShapes(sourceType, resultType);
448  if (firstJoin != newJoin)
449  return failure();
450 
451  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
452  tensorCastOperand.getOperand());
453  return success();
454  }
455 };
456 
457 /// Fold tensor.cast into tensor.extract_slice producer.
458 /// Example:
459 /// ```
460 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
461 /// tensor<128x512xf32> to tensor<?x512xf32>
462 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
463 /// ```
464 /// ->
465 /// ```
466 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
467 /// tensor<128x512xf32> to tensor<16x512xf32>
468 /// ```
469 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
470  using OpRewritePattern<CastOp>::OpRewritePattern;
471 
472  LogicalResult matchAndRewrite(CastOp tensorCast,
473  PatternRewriter &rewriter) const final {
474  auto extractOperand =
475  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
476 
477  // Cannot fold cast to unranked tensor.
478  auto rankedResultType =
479  llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
480  if (!rankedResultType)
481  return failure();
482 
483  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
484  rankedResultType.getShape() ==
485  llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
486  .getShape())
487  return failure();
488 
489  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
490  auto dimMask = computeRankReductionMask(
491  extractOperand.getStaticSizes(), extractOperand.getType().getShape());
492  size_t dimIndex = 0;
493  for (size_t i = 0, e = sizes.size(); i < e; i++) {
494  if (dimMask && dimMask->count(i))
495  continue;
496  int64_t dim = rankedResultType.getShape()[dimIndex++];
497  if (ShapedType::isDynamic(dim))
498  continue;
499  sizes[i] = rewriter.getIndexAttr(dim);
500  }
501 
502  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
503  tensorCast, rankedResultType, extractOperand.getSource(),
504  extractOperand.getMixedOffsets(), sizes,
505  extractOperand.getMixedStrides());
506  return success();
507  }
508 };
509 
510 } // namespace
511 
512 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
513  MLIRContext *context) {
514  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // ConcatOp
519 //===----------------------------------------------------------------------===//
520 
521 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
522  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
523  auto tensorTypes =
524  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
525  return llvm::cast<RankedTensorType>(type);
526  }));
527  int64_t concatRank = tensorTypes[0].getRank();
528 
529  // The concatenation dim must be in the range [0, rank).
530  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
531 
532  SmallVector<int64_t> sizes(concatRank);
533  for (int64_t i = 0, e = concatRank; i < e; ++i) {
534  if (i == dim)
535  continue;
536  SaturatedInteger size;
537  for (auto tensorType : tensorTypes)
538  size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
539  sizes[i] = size.asInteger();
540  }
541  auto concatSize = SaturatedInteger::wrap(0);
542  for (auto tensorType : tensorTypes)
543  concatSize =
544  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
545  sizes[dim] = concatSize.asInteger();
546  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
547 }
548 
549 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
550  ValueRange inputs) {
551  FailureOr<RankedTensorType> resultType =
552  inferResultType(dim, inputs.getTypes());
553  assert(succeeded(resultType) && "failed to infer concatenation result type");
554  build(builder, result, *resultType, dim, inputs);
555 }
556 
557 LogicalResult ConcatOp::verify() {
558  if (getInputs().size() < 1)
559  return emitOpError("requires at least one input");
560 
561  SmallVector<RankedTensorType> inputTypes;
562  for (auto input : getInputs())
563  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
564 
565  RankedTensorType resultType = getResultType();
566  int64_t resultRank = getRank();
567  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
568  return type.getRank() != resultRank;
569  }))
570  return emitOpError("rank of concatenated inputs must match result rank");
571 
572  Type resultElementType = resultType.getElementType();
573  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
574  return type.getElementType() != resultElementType;
575  }))
576  return emitOpError("inputs and result element type must match");
577 
578  int64_t dim = getDim();
579  if (dim >= resultRank)
580  return emitOpError("concatenation dim must be less than the tensor rank");
581 
582  SmallVector<int64_t> sizes(resultRank);
583  for (int64_t i = 0, e = resultRank; i < e; ++i) {
584  if (i == dim)
585  continue;
586  SaturatedInteger size;
587  for (auto tensorType : inputTypes) {
588  FailureOr<SaturatedInteger> maybeSize =
589  size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
590  if (failed(maybeSize))
591  return emitOpError("static concatenation size mismatch along ")
592  << "non-concatenated dimension " << i;
593  size = *maybeSize;
594  }
595  sizes[i] = size.asInteger();
596  }
597  auto concatSize = SaturatedInteger::wrap(0);
598  for (auto tensorType : inputTypes)
599  concatSize =
600  concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
601  sizes[dim] = concatSize.asInteger();
602  auto inferredResultType =
603  RankedTensorType::get(sizes, inputTypes[0].getElementType());
604 
605  for (auto [inferredSize, actualSize] :
606  llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
607  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
608  ShapedType::isDynamic(actualSize);
609  if (!hasDynamic && inferredSize != actualSize)
610  return emitOpError("result type ")
611  << resultType << " does not match inferred shape "
612  << inferredResultType << " static sizes";
613  }
614 
615  return success();
616 }
617 
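/// Decomposition sketch (illustrative): the concat is rewritten into an
/// empty destination plus one insert_slice per input, e.g. for
/// `tensor.concat dim(0) %a, %b : (tensor<2x4xf32>, tensor<3x4xf32>) ->
/// tensor<5x4xf32>`:
///
/// ```mlir
/// %empty = tensor.empty() : tensor<5x4xf32>
/// %0 = tensor.insert_slice %a into %empty[0, 0] [2, 4] [1, 1]
///     : tensor<2x4xf32> into tensor<5x4xf32>
/// %1 = tensor.insert_slice %b into %0[2, 0] [3, 4] [1, 1]
///     : tensor<3x4xf32> into tensor<5x4xf32>
/// ```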
618 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
619  size_t numInputs = getInputs().size();
620  uint64_t concatDim = getDim();
621 
622  SmallVector<SmallVector<OpFoldResult>> inputShapes;
623  inputShapes.reserve(numInputs);
624  SmallVector<OpFoldResult> concatOffsets;
625  concatOffsets.reserve(numInputs);
626  SmallVector<OpFoldResult> outputShape;
627 
628  AffineExpr addExpr =
629  builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
630  OpFoldResult zero = builder.getIndexAttr(0);
631  Location loc = getLoc();
632  for (auto [index, input] : llvm::enumerate(getInputs())) {
633  SmallVector<OpFoldResult> inputShape =
634  tensor::getMixedSizes(builder, input.getLoc(), input);
635  if (index == 0) {
636  outputShape = inputShape;
637  concatOffsets.push_back(zero);
638  } else {
639  concatOffsets.push_back(outputShape[concatDim]);
640  outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
641  builder, loc, addExpr,
642  {outputShape[concatDim], inputShape[concatDim]});
643  }
644  inputShapes.emplace_back(std::move(inputShape));
645  }
646 
647  Value replacement = builder.create<tensor::EmptyOp>(
648  loc, outputShape, getType().getElementType());
649 
650  int64_t rank = getType().getRank();
651  OpFoldResult one = builder.getIndexAttr(1);
652  SmallVector<OpFoldResult> strides(rank, one);
653  SmallVector<OpFoldResult> offsets(rank, zero);
654  for (auto [index, input] : llvm::enumerate(getInputs())) {
655  offsets[concatDim] = concatOffsets[index];
656  auto insertSlice = builder.create<tensor::InsertSliceOp>(
657  loc, input, replacement, offsets, inputShapes[index], strides);
658  replacement = insertSlice.getResult();
659  }
660  if (replacement.getType() != getType()) {
661  replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
662  }
663  return SmallVector<Value>{replacement};
664 }
665 
666 LogicalResult
667 ConcatOp::reifyResultShapes(OpBuilder &builder,
668  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
669  ValueRange inputs = getInputs();
670  int64_t dim = getDim();
671  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
672 
673  Value init = inputs[0];
674  int64_t rank = getType().getRank();
675 
676  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
677 
678  // Pre-populate the result sizes with as much static information as possible
679  // from the given result type and the inferred result type; otherwise, use
680  // the dim sizes from the first input.
681  for (int64_t i = 0; i < rank; ++i) {
682  if (i == dim)
683  continue;
684  if (!getType().isDynamicDim(i)) {
685  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
686  } else if (!inferredResultType.isDynamicDim(i)) {
687  reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
688  builder, getLoc(),
689  builder.getIndexAttr(inferredResultType.getDimSize(i)));
690  } else {
691  reifiedReturnShapes[0][i] =
692  builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
693  }
694  }
695 
696  if (getType().isDynamicDim(dim)) {
697  // Take the sum of the input sizes along the concatenated dim.
698  AffineExpr sum = builder.getAffineDimExpr(0);
699  SmallVector<OpFoldResult> sizes = {
700  builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
701  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
702  sum = sum + builder.getAffineDimExpr(idx + 1);
703  sizes.push_back(
704  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
705  }
706  reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
707  builder, getLoc(),
708  affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
709  } else {
710  // If the result shape is static along the concatenated dim, use the static
711  // shape.
712  reifiedReturnShapes[0][dim] =
713  builder.getIndexAttr(getType().getDimSize(dim));
714  }
715  return success();
716 }
717 
718 void ConcatOp::getAsmResultNames(
719  function_ref<void(Value, StringRef)> setNameFn) {
720  setNameFn(getResult(), "concat");
721 }
722 
723 OpFoldResult ConcatOp::fold(FoldAdaptor) {
724  ValueRange inputs = getInputs();
725  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
726  return inputs[0];
727  return {};
728 }
729 
730 namespace {
731 /// Fold a concat op with a single input to a cast.
732 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
733  using OpRewritePattern<ConcatOp>::OpRewritePattern;
734 
735  LogicalResult matchAndRewrite(ConcatOp concatOp,
736  PatternRewriter &rewriter) const override {
737  if (concatOp.getInputs().size() != 1)
738  return failure();
739  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
740  concatOp.getInputs()[0]);
741  return success();
742  }
743 };
744 } // namespace
745 
746 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
747  MLIRContext *context) {
748  results.add<SingleInputConcatOp>(context);
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // DimOp
753 //===----------------------------------------------------------------------===//
754 
755 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
756  setNameFn(getResult(), "dim");
757 }
758 
759 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
760  int64_t index) {
761  auto loc = result.location;
762  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
763  build(builder, result, source, indexValue);
764 }
765 
766 std::optional<int64_t> DimOp::getConstantIndex() {
767  return getConstantIntValue(getIndex());
768 }
769 
770 Speculation::Speculatability DimOp::getSpeculatability() {
771  auto constantIndex = getConstantIndex();
772  if (!constantIndex)
773  return Speculation::NotSpeculatable;
774 
775  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
776  if (!rankedSourceType)
777  return Speculation::NotSpeculatable;
778 
779  if (rankedSourceType.getRank() <= constantIndex)
780  return Speculation::NotSpeculatable;
781 
782  return Speculation::Speculatable;
783 }
784 
785 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
786  // All forms of folding require a known index.
787  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
788  if (!index)
789  return {};
790 
791  // Folding for unranked types (UnrankedTensorType) is not supported.
792  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
793  if (!tensorType)
794  return {};
795 
796  // Out of bound indices produce undefined behavior but are still valid IR.
797  // Don't choke on them.
798  int64_t indexVal = index.getInt();
799  if (indexVal < 0 || indexVal >= tensorType.getRank())
800  return {};
801 
802  // Fold if the shape extent along the given index is known.
803  if (!tensorType.isDynamicDim(index.getInt())) {
804  Builder builder(getContext());
805  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
806  }
807 
808  Operation *definingOp = getSource().getDefiningOp();
809 
810  // Fold dim to the operand of tensor.generate.
811  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
812  auto resultType =
813  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
814  // The case where the type encodes the size of the dimension is handled
815  // above.
816  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
817 
818  // Find the operand of the fromElements that corresponds to this index.
819  auto dynExtents = fromElements.getDynamicExtents().begin();
820  for (auto dim : resultType.getShape().take_front(index.getInt()))
821  if (ShapedType::isDynamic(dim))
822  dynExtents++;
823 
824  return Value{*dynExtents};
825  }
826 
827  // The size at the given index is now known to be a dynamic size.
828  unsigned unsignedIndex = index.getValue().getZExtValue();
829 
830  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
831  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
832  // `resolve-shaped-type-result-dims` pass.
833  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
834  sliceOp.isDynamicSize(unsignedIndex)) {
835  return {sliceOp.getDynamicSize(unsignedIndex)};
836  }
837  }
838 
839  // dim(cast) -> dim
840  if (succeeded(foldTensorCast(*this)))
841  return getResult();
842 
843  return {};
844 }
845 
846 namespace {
847 /// Fold dim of a cast into the dim of the source of the tensor cast.
848 struct DimOfCastOp : public OpRewritePattern<DimOp> {
849  using OpRewritePattern<DimOp>::OpRewritePattern;
850 
851  LogicalResult matchAndRewrite(DimOp dimOp,
852  PatternRewriter &rewriter) const override {
853  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
854  if (!castOp)
855  return failure();
856  Value newSource = castOp.getOperand();
857  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
858  return success();
859  }
860 };
861 
862 /// Fold dim of a destination passing style op into the dim of the corresponding
863 /// init.
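/// For illustration (hypothetical values; a linalg op stands in for any
/// destination-style producer):
///
/// ```mlir
/// %r = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
/// %d = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
///
/// is rewritten so that the dim queries %init instead of %r.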
864 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
865  using OpRewritePattern<DimOp>::OpRewritePattern;
866 
867  LogicalResult matchAndRewrite(DimOp dimOp,
868  PatternRewriter &rewriter) const override {
869  auto source = dimOp.getSource();
870  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
871  if (!destOp)
872  return failure();
873 
874  auto resultIndex = cast<OpResult>(source).getResultNumber();
875  auto *initOperand = destOp.getDpsInitOperand(resultIndex);
876 
877  rewriter.modifyOpInPlace(
878  dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
879  return success();
880  }
881 };
882 
883 /// Fold dim of a tensor reshape operation to an extract into the reshape's
884 /// shape operand.
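/// For example (illustrative, hypothetical values):
///
/// ```mlir
/// %r = tensor.reshape %t(%shape) : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// %d = tensor.dim %r, %idx : tensor<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %d = tensor.extract %shape[%idx] : tensor<2xindex>
/// ```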
885 struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
886  using OpRewritePattern<DimOp>::OpRewritePattern;
887 
888  LogicalResult matchAndRewrite(DimOp dim,
889  PatternRewriter &rewriter) const override {
890  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
891 
892  if (!reshape)
893  return failure();
894 
895  // Since tensors are immutable we don't need to worry about where to place
896  // the extract call
897  rewriter.setInsertionPointAfter(dim);
898  Location loc = dim.getLoc();
899  Value extract =
900  rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
901  if (extract.getType() != dim.getType())
902  extract =
903  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
904  rewriter.replaceOp(dim, extract);
905  return success();
906  }
907 };
908 } // namespace
909 
910 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
911  MLIRContext *context) {
912  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
913 }
914 
915 //===----------------------------------------------------------------------===//
916 // EmptyOp
917 //===----------------------------------------------------------------------===//
918 
919 void EmptyOp::build(OpBuilder &builder, OperationState &result,
920  ArrayRef<int64_t> staticShape, Type elementType,
921  Attribute encoding) {
922  assert(all_of(staticShape,
923  [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
924  "expected only static sizes");
925  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
926 }
927 
928 void EmptyOp::build(OpBuilder &builder, OperationState &result,
929  ArrayRef<int64_t> staticShape, Type elementType,
930  ValueRange dynamicSizes, Attribute encoding) {
931  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
932  build(builder, result, tensorType, dynamicSizes);
933 }
934 
935 void EmptyOp::build(OpBuilder &builder, OperationState &result,
936  ArrayRef<OpFoldResult> sizes, Type elementType,
937  Attribute encoding) {
938  SmallVector<int64_t> staticShape;
939  SmallVector<Value> dynamicSizes;
940  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
941  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
942 }
943 
944 LogicalResult EmptyOp::verify() {
945  if (getType().getNumDynamicDims() != getDynamicSizes().size())
946  return emitOpError("incorrect number of dynamic sizes, has ")
947  << getDynamicSizes().size() << ", expected "
948  << getType().getNumDynamicDims();
949  return success();
950 }
951 
952 LogicalResult
953 EmptyOp::reifyResultShapes(OpBuilder &builder,
954  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
955  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
956  unsigned ctr = 0;
957  for (int64_t i = 0; i < getType().getRank(); ++i) {
958  if (getType().isDynamicDim(i)) {
959  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
960  } else {
961  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
962  }
963  }
964  return success();
965 }
966 
967 Value EmptyOp::getDynamicSize(unsigned idx) {
968  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
969  unsigned ctr = 0;
970  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
971  if (getType().isDynamicDim(i))
972  ++ctr;
973  return getDynamicSizes()[ctr];
974 }
975 
976 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
977  SmallVector<OpFoldResult> result;
978  unsigned ctr = 0;
979  OpBuilder b(getContext());
980  for (int64_t i = 0; i < getType().getRank(); ++i) {
981  if (getType().isDynamicDim(i)) {
982  result.push_back(getDynamicSizes()[ctr++]);
983  } else {
984  result.push_back(b.getIndexAttr(getType().getShape()[i]));
985  }
986  }
987  return result;
988 }
989 
990 namespace {
991 /// Change the type of the result of a `tensor.empty` by making the result
992 /// type statically sized along dimensions that in the original operation were
993 /// defined as dynamic, but the size was defined using a `constant` op. For
994 /// example
995 ///
996 /// %c5 = arith.constant 5: index
997 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
998 ///
999 /// to
1000 ///
1001 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1002 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1003  using OpRewritePattern<EmptyOp>::OpRewritePattern;
1004 
1005  LogicalResult matchAndRewrite(EmptyOp op,
1006  PatternRewriter &rewriter) const override {
1007  SmallVector<Value> foldedDynamicSizes;
1008  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1009  op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1010 
1011  // Stop here if no dynamic size was promoted to static.
1012  if (foldedTensorType == op.getType())
1013  return failure();
1014 
1015  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
1016  foldedDynamicSizes);
1017  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
1018  return success();
1019  }
1020 };
1021 
1022 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1023  using OpRewritePattern<DimOp>::OpRewritePattern;
1024 
1025  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1026  PatternRewriter &rewriter) const override {
1027  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1028  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1029  if (!emptyTensorOp || !maybeConstantIndex)
1030  return failure();
1031  auto emptyTensorType = emptyTensorOp.getType();
1032  if (*maybeConstantIndex < 0 ||
1033  *maybeConstantIndex >= emptyTensorType.getRank() ||
1034  !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1035  return failure();
1036  rewriter.replaceOp(dimOp,
1037  emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1038  return success();
1039  }
1040 };
1041 
1042 /// Canonicalize
1043 ///
1044 /// ```mlir
1045 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1046 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1047 /// ```
1048 ///
1049 /// into
1050 ///
1051 /// ```mlir
1052 /// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1053 /// ```
1054 ///
1055 /// This assumes the input program is correct in terms of its shape. So it is
1056 /// safe to assume that `%d0` is in fact 4.
1057 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
1058  using OpRewritePattern<CastOp>::OpRewritePattern;
1059 
1060  LogicalResult matchAndRewrite(CastOp castOp,
1061  PatternRewriter &rewriter) const override {
1062  if (!canFoldIntoProducerOp(castOp))
1063  return failure();
1064  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1065  if (!producer)
1066  return failure();
1067 
1068  auto resultType =
1069  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1070  ArrayRef<int64_t> resultShape = resultType.getShape();
1071  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1072  SmallVector<OpFoldResult> newMixedSizes;
1073  newMixedSizes.reserve(currMixedSizes.size());
1074  assert(resultShape.size() == currMixedSizes.size() &&
1075  "mismatch in result shape and sizes of empty op");
1076  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
1077  int64_t newDim = std::get<0>(it);
1078  OpFoldResult currDim = std::get<1>(it);
1079  // Case 1: The empty tensor dim is static. Check that the tensor cast
1080  // result dim matches.
1081  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1082  if (ShapedType::isDynamic(newDim) ||
1083  newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1084  // Something is off, the cast result shape cannot be more dynamic
1085  // than the empty tensor result shape (enforced by
1086  // `canFoldIntoProducer`). Abort for now.
1087  return rewriter.notifyMatchFailure(
1088  producer, "mismatch in static value of shape of empty tensor "
1089  "result and cast result");
1090  }
1091  newMixedSizes.push_back(attr);
1092  continue;
1093  }
1094 
1095  // Case 2 : The tensor cast shape is static, but empty tensor result
1096  // shape is dynamic.
1097  if (!ShapedType::isDynamic(newDim)) {
1098  newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
1099  continue;
1100  }
1101 
1102  // Case 3 : The tensor cast shape is dynamic and empty tensor result
1103  // shape is dynamic. Use the dynamic value from the empty tensor op.
1104  newMixedSizes.push_back(currDim);
1105  }
1106 
1107  // TODO: Do not drop tensor encoding.
1108  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
1109  resultType.getElementType());
1110  return success();
1111  }
1112 };
1113 
1114 } // namespace
1115 
1116 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1117  MLIRContext *context) {
1118  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1119  ReplaceEmptyTensorStaticShapeDims>(context);
1120 }
1121 
1122 /// Try to remove a tensor operation if it would only reshape a constant.
1123 /// Removes the op and replaces the constant with a new constant of the result
1124 /// shape. When an optional cst attribute is passed, it is reshaped only if the
1125 /// splat value matches the value in the attribute.
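/// For illustration (hypothetical values): a splat source
/// `dense<1.0> : tensor<4xf32>` reshaped to a static `tensor<2x2xf32>` result
/// folds to `dense<1.0> : tensor<2x2xf32>`.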
1126 static OpFoldResult
1127 reshapeConstantSource(DenseElementsAttr source, TensorType result,
1128  std::optional<Attribute> cst = std::nullopt) {
1129  if (source && source.isSplat() && result.hasStaticShape() &&
1130  (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
1131  return source.resizeSplat(result);
1132 
1133  return {};
1134 }
1135 
1136 //===----------------------------------------------------------------------===//
1137 // ExtractOp
1138 //===----------------------------------------------------------------------===//
1139 
1140 namespace {
1141 
1142 /// Canonicalizes the pattern of the form
1143 ///
1144 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
1145 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1146 ///
1147 /// to
1148 ///
1149 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1150 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1151  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1152 
1153  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1154  PatternRewriter &rewriter) const final {
1155  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1156  if (!tensorCast)
1157  return failure();
1158  if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1159  return failure();
1160  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1161  extract, tensorCast.getSource(), extract.getIndices());
1162  return success();
1163  }
1164 };
1165 
1166 } // namespace
1167 
1168 void ExtractOp::getAsmResultNames(
1169  function_ref<void(Value, StringRef)> setNameFn) {
1170  setNameFn(getResult(), "extracted");
1171 }
1172 
1173 LogicalResult ExtractOp::verify() {
1174  // Verify the # indices match if we have a ranked type.
1175  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1176  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1177  return emitOpError("incorrect number of indices for extract_element");
1178  return success();
1179 }
1180 
1181 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1182  if (Attribute tensor = adaptor.getTensor()) {
1183  // If this is a splat elements attribute, simply return the value.
1184  // All of the elements of a splat attribute are the same.
1185  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1186  return splatTensor.getSplatValue<Attribute>();
1187 
1188  // If this is a dense resource elements attribute, return.
1189  if (isa<DenseResourceElementsAttr>(tensor))
1190  return {};
1191  }
1192 
1193  // Collect the constant indices into the tensor.
1194  SmallVector<uint64_t, 8> indices;
1195  for (Attribute indice : adaptor.getIndices()) {
1196  if (!indice || !llvm::isa<IntegerAttr>(indice))
1197  return {};
1198  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1199  }
1200 
1201  // Fold extract(from_elements(...)).
1202  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1203  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1204  auto rank = tensorType.getRank();
1205  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1206  "rank mismatch");
1207  int flatIndex = 0;
1208  int stride = 1;
1209  for (int i = rank - 1; i >= 0; --i) {
1210  flatIndex += indices[i] * stride;
1211  stride *= tensorType.getDimSize(i);
1212  }
1213  // Prevent out of bounds accesses. This can happen in invalid code that
1214  // will never execute.
1215  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1216  flatIndex < 0)
1217  return {};
1218  return fromElementsOp.getElements()[flatIndex];
1219  }
1220 
1221  // If this is an elements attribute, query the value at the given indices.
1222  if (Attribute tensor = adaptor.getTensor()) {
1223  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1224  if (elementsAttr && elementsAttr.isValidIndex(indices))
1225  return elementsAttr.getValues<Attribute>()[indices];
1226  }
1227 
1228  return {};
1229 }
1230 
1231 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1232  MLIRContext *context) {
1233  results.add<ExtractFromTensorCast>(context);
1234 }
1235 
1236 //===----------------------------------------------------------------------===//
1237 // FromElementsOp
1238 //===----------------------------------------------------------------------===//
1239 
1240 void FromElementsOp::getAsmResultNames(
1241  function_ref<void(Value, StringRef)> setNameFn) {
1242  setNameFn(getResult(), "from_elements");
1243 }
1244 
1245 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1246  ValueRange elements) {
1247  assert(!elements.empty() && "expected at least one element");
1248  Type resultType = RankedTensorType::get(
1249  {static_cast<int64_t>(elements.size())}, elements.front().getType());
1250  build(builder, result, resultType, elements);
1251 }
1252 
1253 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1254  if (!llvm::is_contained(adaptor.getElements(), nullptr))
1255  return DenseElementsAttr::get(getType(), adaptor.getElements());
1256  return {};
1257 }
1258 
1259 namespace {
1260 
1261 // Pushes the index_casts that occur before extractions to after the extract.
1262 // This minimizes type conversion in some cases and enables the extract
1263 // canonicalizer. This changes:
1264 //
1265 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1266 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
1267 //
1268 // to the following:
1269 //
1270 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
1271 // %cast = arith.index_cast %extract : i32 to index
1272 //
1273 // so the index_cast now operates on a scalar instead of a tensor.
1274 //
1275 // Consider expanding this to a template and handle all tensor cast
1276 // operations.
1277 struct ExtractElementFromIndexCast
1278  : public OpRewritePattern<tensor::ExtractOp> {
1279  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1280 
1281  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1282  PatternRewriter &rewriter) const final {
1283  Location loc = extract.getLoc();
1284  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1285  if (!indexCast)
1286  return failure();
1287 
1288  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
1289 
1290  auto newExtract = rewriter.create<tensor::ExtractOp>(
1291  loc, elementTy, indexCast.getIn(), extract.getIndices());
1292 
1293  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
1294  newExtract);
1295 
1296  return success();
1297  }
1298 };
1299 
1300 } // namespace
1301 
1302 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1303  MLIRContext *context) {
1304  results.add<ExtractElementFromIndexCast>(context);
1305 }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // GatherOp
1309 //===----------------------------------------------------------------------===//
1310 
1311 void GatherOp::getAsmResultNames(
1312  function_ref<void(Value, StringRef)> setNameFn) {
1313  setNameFn(getResult(), "gather");
1314 }
1315 
1316 /// Return the inferred result type for a gatherOp where:
1317 /// - sourceType is the type of the source tensor gathered from
1318 /// - indicesType is the type of the indices used to gather
1319 /// - gatherDims are the dims along which the gather occurs.
1320 /// Return a full-rank or rank-reduced variant of the type depending on
1321 /// the value of rankReduced.
1322 ///
1323 /// The leading dimensions of the index tensor give the result tensor its
1324 /// leading dimensions.
1325 /// The trailing dimensions of the result tensor are obtained from the source
1326 /// tensor by setting the dimensions specified in gather_dims to `1` (if
1327 /// rankReduced is false), or skipping them (otherwise).
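/// For illustration (hypothetical shapes): with source tensor<4x5x6xf32>,
/// indices tensor<1x7x2xindex> and gather_dims = [0, 2], the inferred type is
/// tensor<1x7x1x5x1xf32> (full rank) or tensor<1x7x5xf32> (rank-reduced).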
1328 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1329  RankedTensorType indicesType,
1330  ArrayRef<int64_t> gatherDims,
1331  bool rankReduced) {
1332  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1333  resultShape.reserve(resultShape.size() + sourceType.getRank());
1334  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1335  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1336  if (!rankReduced)
1337  resultShape.push_back(1);
1338  continue;
1339  }
1340  resultShape.push_back(sourceType.getDimSize(idx));
1341  }
1342  return RankedTensorType::Builder(sourceType).setShape(resultShape);
1343 }
1344 
1345 static LogicalResult
1346 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1347  ArrayRef<int64_t> indices, int64_t rank,
1348  StringRef gatherOrScatter, StringRef sourceOrDest) {
1349  if (dims.empty())
1350  return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
1351 
1352  int64_t numGatherDims = dims.size();
1353  if (numGatherDims > rank)
1354  return op->emitOpError(gatherOrScatter)
1355  << "_dims overflow " << sourceOrDest << " rank";
1356  if (indices.empty() || indices.back() != numGatherDims)
1357  return op->emitOpError(gatherOrScatter)
1358  << "_dims length must match the size of last dimension of indices";
1359  for (int64_t val : dims) {
1360  if (val < 0)
1361  return op->emitOpError(gatherOrScatter)
1362  << "_dims value must be non-negative";
1363  if (val >= rank)
1364  return op->emitOpError(gatherOrScatter)
1365  << "_dims value must be smaller than " << sourceOrDest << " rank";
1366  }
1367  for (int64_t i = 1; i < numGatherDims; ++i) {
1368  if (dims[i - 1] >= dims[i])
1369  return op->emitOpError(gatherOrScatter)
1370  << "_dims values must be strictly increasing";
1371  }
1372  return success();
1373 }
1374 
1375 LogicalResult GatherOp::verify() {
1376  int64_t sourceRank = getSourceType().getRank();
1377  ArrayRef<int64_t> gatherDims = getGatherDims();
1378  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1379  getIndicesType().getShape(), sourceRank,
1380  "gather", "source")))
1381  return failure();
1382 
1383  RankedTensorType expectedResultType = GatherOp::inferResultType(
1384  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1385  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1386  getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1387  if (getResultType() != expectedResultType &&
1388  getResultType() != expectedRankReducedResultType) {
1389  return emitOpError("result type "
1390  "mismatch: "
1391  "expected ")
1392  << expectedResultType << " or its rank-reduced variant "
1393  << expectedRankReducedResultType << " (got: " << getResultType()
1394  << ")";
1395  }
1396 
1397  return success();
1398 }
1399 
1400 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1401  if (OpFoldResult reshapedSource = reshapeConstantSource(
1402  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1403  getResult().getType()))
1404  return reshapedSource;
1405  return {};
1406 }
1407 
1408 //===----------------------------------------------------------------------===//
1409 // InsertOp
1410 //===----------------------------------------------------------------------===//
1411 
1412 void InsertOp::getAsmResultNames(
1413  function_ref<void(Value, StringRef)> setNameFn) {
1414  setNameFn(getResult(), "inserted");
1415 }
1416 
1417 LogicalResult InsertOp::verify() {
1418  // Verify the # indices match if we have a ranked type.
1419  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1420  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1421  return emitOpError("incorrect number of indices");
1422  return success();
1423 }
1424 
1425 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1426  Attribute scalar = adaptor.getScalar();
1427  Attribute dest = adaptor.getDest();
1428  if (scalar && dest)
1429  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1430  if (scalar == splatDest.getSplatValue<Attribute>())
1431  return dest;
1432  return {};
1433 }
1434 
1435 //===----------------------------------------------------------------------===//
1436 // GenerateOp
1437 //===----------------------------------------------------------------------===//
1438 
1439 void GenerateOp::getAsmResultNames(
1440  function_ref<void(Value, StringRef)> setNameFn) {
1441  setNameFn(getResult(), "generated");
1442 }
1443 
1444 LogicalResult GenerateOp::reifyResultShapes(
1445  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1446  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1447  int idx = 0;
1448  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1449  if (getType().isDynamicDim(dim)) {
1450  reifiedReturnShapes[0][dim] = getOperand(idx++);
1451  } else {
1452  reifiedReturnShapes[0][dim] =
1453  builder.getIndexAttr(getType().getDimSize(dim));
1454  }
1455  }
1456  return success();
1457 }
1458 
1459 LogicalResult GenerateOp::verify() {
1460  // Ensure that the tensor type has as many dynamic dimensions as are
1461  // specified by the operands.
1462  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1463  if (getNumOperands() != resultType.getNumDynamicDims())
1464  return emitError("must have as many index operands as dynamic extents "
1465  "in the result type");
1466  return success();
1467 }
1468 
1469 LogicalResult GenerateOp::verifyRegions() {
1470  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1471  // Ensure that region arguments span the index space.
1472  if (!llvm::all_of(getBody().getArgumentTypes(),
1473  [](Type ty) { return ty.isIndex(); }))
1474  return emitError("all body arguments must be index");
1475  if (getBody().getNumArguments() != resultTy.getRank())
1476  return emitError("must have one body argument per input dimension");
1477 
1478  // Ensure that the region yields an element of the right type.
1479  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1480 
1481  if (yieldOp.getValue().getType() != resultTy.getElementType())
1482  return emitOpError(
1483  "body must be terminated with a `yield` operation of the tensor "
1484  "element type");
1485 
1486  return success();
1487 }
1488 
1489 void GenerateOp::build(
1490  OpBuilder &b, OperationState &result, Type resultTy,
1491  ValueRange dynamicExtents,
1492  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1493  build(b, result, resultTy, dynamicExtents);
1494 
1495  // Build and populate body.
1496  OpBuilder::InsertionGuard guard(b);
1497  Region *bodyRegion = result.regions.front().get();
1498  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1499  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1500  SmallVector<Location, 2> argumentLocs(rank, result.location);
1501  Block *bodyBlock =
1502  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1503  bodyBuilder(b, result.location, bodyBlock->getArguments());
1504 }
1505 
1506 namespace {
1507 
1508 /// Canonicalizes tensor.generate operations with a constant
1509 /// operand into the equivalent operation with the operand expressed in the
1510 /// result type, instead. We also insert a type cast to make sure that the
1511 /// resulting IR is still well-typed.
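/// A sketch of the rewrite (hypothetical values):
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = tensor.generate %c8 {
/// ^bb0(%i : index):
///   <computation>
///   yield %elem : index
/// } : tensor<?xindex>
/// ```
///
/// becomes a tensor.generate producing tensor<8xindex>, followed by a
/// tensor.cast back to tensor<?xindex>.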
1512 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1513  using OpRewritePattern<GenerateOp>::OpRewritePattern;
1514 
1515  LogicalResult matchAndRewrite(GenerateOp generateOp,
1516  PatternRewriter &rewriter) const final {
1517  SmallVector<Value> foldedDynamicSizes;
1518  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1519  generateOp.getType(), generateOp.getDynamicExtents(),
1520  foldedDynamicSizes);
1521 
1522  // Stop here if no dynamic size was promoted to static.
1523  if (foldedTensorType == generateOp.getType())
1524  return failure();
1525 
1526  auto loc = generateOp.getLoc();
1527  auto newOp =
1528  rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1529  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1530  newOp.getBody().begin());
1531  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1532  generateOp.getType(), newOp);
1533  return success();
1534  }
1535 };
1536 
1537 /// Canonicalizes the pattern of the form
1538 ///
1539 /// %tensor = tensor.generate %x {
1540 /// ^bb0(%arg0: index):
1541 /// <computation>
1542 /// yield %1 : index
1543 /// } : tensor<?xindex>
1544 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1545 ///
1546 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
1547 /// tensor.generate operation has no side-effects.
1548 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1549  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1550 
1551  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1552  PatternRewriter &rewriter) const final {
1553  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1554  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1555  return failure();
1556 
1557  IRMapping mapping;
1558  Block *body = &tensorFromElements.getBody().front();
1559  mapping.map(body->getArguments(), extract.getIndices());
1560  for (auto &op : body->without_terminator())
1561  rewriter.clone(op, mapping);
1562 
1563  auto yield = cast<YieldOp>(body->getTerminator());
1564 
1565  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1566  return success();
1567  }
1568 };
1569 
1570 } // namespace
1571 
1572 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1573  MLIRContext *context) {
1574  // TODO: Move extract pattern to tensor::ExtractOp.
1575  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1576 }
1577 
1578 //===----------------------------------------------------------------------===//
1579 // RankOp
1580 //===----------------------------------------------------------------------===//
1581 
1582 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1583  setNameFn(getResult(), "rank");
1584 }
1585 
1586 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1587  // Constant fold rank when the rank of the operand is known.
1588  auto type = getOperand().getType();
1589  auto shapedType = llvm::dyn_cast<ShapedType>(type);
1590  if (shapedType && shapedType.hasRank())
1591  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1592  return IntegerAttr();
1593 }
1594 
1595 //===----------------------------------------------------------------------===//
1596 // ReshapeOp
1597 //===----------------------------------------------------------------------===//
1598 
1599 void ReshapeOp::getAsmResultNames(
1600  function_ref<void(Value, StringRef)> setNameFn) {
1601  setNameFn(getResult(), "reshape");
1602 }
1603 
1604 static int64_t getNumElements(ShapedType type) {
1605  int64_t numElements = 1;
1606  for (auto dim : type.getShape())
1607  numElements *= dim;
1608  return numElements;
1609 }
1610 
1611 LogicalResult ReshapeOp::verify() {
1612  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1613  TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1614 
1615  if (operandType.getElementType() != resultType.getElementType())
1616  return emitOpError("element types of source and destination tensor "
1617  "types should be the same");
1618 
1619  int64_t shapeSize =
1620  llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1621  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1622  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1623 
1624  if (resultRankedType) {
1625  if (operandRankedType && resultRankedType.hasStaticShape() &&
1626  operandRankedType.hasStaticShape()) {
1627  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1628  return emitOpError("source and destination tensor should have the "
1629  "same number of elements");
1630  }
1631  if (ShapedType::isDynamic(shapeSize))
1632  return emitOpError("cannot use shape operand with dynamic length to "
1633  "reshape to statically-ranked tensor type");
1634  if (shapeSize != resultRankedType.getRank())
1635  return emitOpError(
1636  "length of shape operand differs from the result's tensor rank");
1637  }
1638  return success();
1639 }
1640 
1641 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1642  if (OpFoldResult reshapedSource = reshapeConstantSource(
1643  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1644  getResult().getType()))
1645  return reshapedSource;
1646 
1647  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1648  // producer's input instead as the original tensor to reshape. This could
1649  // render such producer dead code.
1650  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1651  getSourceMutable().assign(reshapeOpProducer.getSource());
1652  return getResult();
1653  }
1654 
1655  auto source = getSource();
1656  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1657  auto resultTy = dyn_cast<RankedTensorType>(getType());
1658  if (!sourceTy || !resultTy || sourceTy != resultTy)
1659  return {};
1660 
1661  // If the source and result are both 1D tensors and have the same type, the
1662  // reshape has no effect, even if the tensor is dynamically shaped.
1663  if (sourceTy.getRank() == 1)
1664  return source;
1665 
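  // A reshape is also a no-op when the requested shape is assembled by
  // tensor.from_elements from values that provably equal the source's sizes:
  // each element must be either a constant equal to the static source dim or
  // a tensor.dim of the source at the same position.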
1666  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1667  auto elements = fromElements.getElements();
1668  bool dynamicNoop =
1669  sourceTy.getRank() == static_cast<int64_t>(elements.size());
1670  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1671  auto element = elements[id];
1672 
1673  if (auto cst = getConstantIntValue(element)) {
1674  dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1675  continue;
1676  }
1677 
1678  if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1679  dynamicNoop &= dimOp.getSource() == source;
1680 
1681  APSInt dim;
1682  auto cst = getConstantIntValue(dimOp.getIndex());
1683  dynamicNoop &=
1684  cst.has_value() && cst.value() == static_cast<int64_t>(id);
1685  continue;
1686  }
1687 
1688  dynamicNoop = false;
1689  break;
1690  }
1691 
1692  if (dynamicNoop)
1693  return source;
1694  }
1695 
1696  return {};
1697 }
1698 
1699 //===----------------------------------------------------------------------===//
1700 // Reassociative reshape ops
1701 //===----------------------------------------------------------------------===//
1702 
1703 void CollapseShapeOp::getAsmResultNames(
1704  function_ref<void(Value, StringRef)> setNameFn) {
1705  setNameFn(getResult(), "collapsed");
1706 }
1707 
1708 void ExpandShapeOp::getAsmResultNames(
1709  function_ref<void(Value, StringRef)> setNameFn) {
1710  setNameFn(getResult(), "expanded");
1711 }
1712 
1713 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1714  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1715  "invalid resultDim");
1716  for (const auto &it : llvm::enumerate(getReassociationIndices()))
1717  if (llvm::is_contained(it.value(), resultDim))
1718  return it.index();
1719  llvm_unreachable("could not find reassociation group");
1720 }
1721 
1722 FailureOr<SmallVector<OpFoldResult>>
1723 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1724  RankedTensorType expandedType,
1725  ArrayRef<ReassociationIndices> reassociation,
1726  ArrayRef<OpFoldResult> inputShape) {
1727  std::optional<SmallVector<OpFoldResult>> outputShape =
1728  inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1729  inputShape);
1730  if (!outputShape)
1731  return failure();
1732  return *outputShape;
1733 }
1734 
1735 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1736  Type resultType, Value src,
1737  ArrayRef<ReassociationIndices> reassociation,
1738  ArrayRef<OpFoldResult> outputShape) {
1739  auto [staticOutputShape, dynamicOutputShape] =
1740      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1741  build(builder, result, cast<RankedTensorType>(resultType), src,
1742  getReassociationIndicesAttribute(builder, reassociation),
1743  dynamicOutputShape, staticOutputShape);
1744 }
1745 
1746 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1747  Type resultType, Value src,
1748  ArrayRef<ReassociationIndices> reassociation) {
1749  SmallVector<OpFoldResult> inputShape =
1750  getMixedSizes(builder, result.location, src);
1751  auto tensorResultTy = cast<RankedTensorType>(resultType);
1752  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1753  builder, result.location, tensorResultTy, reassociation, inputShape);
1754  SmallVector<OpFoldResult> outputShapeOrEmpty;
1755  if (succeeded(outputShape)) {
1756  outputShapeOrEmpty = *outputShape;
1757  }
1758  build(builder, result, tensorResultTy, src, reassociation,
1759  outputShapeOrEmpty);
1760 }
1761 
1762 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1763  return getSymbolLessAffineMaps(getReassociationExprs());
1764 }
1765 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1766  return convertReassociationIndicesToExprs(getContext(),
1767                                             getReassociationIndices());
1768 }
1769 
1770 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1771  return getSymbolLessAffineMaps(getReassociationExprs());
1772 }
1773 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1774  return convertReassociationIndicesToExprs(getContext(),
1775                                             getReassociationIndices());
1776 }
1777 
1778 RankedTensorType CollapseShapeOp::inferCollapsedType(
1779  RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1780  return inferCollapsedType(
1781      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1782                type.getContext(), reassociation)));
1783 }
1784 
1785 /// Compute the RankedTensorType obtained by applying `reassociation` to
1786 /// `type`.
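/// For example (illustrative): collapsing tensor<2x3x?x5xf32> with
/// reassociation [[0, 1], [2, 3]] yields tensor<6x?xf32>, since the first
/// group multiplies 2 * 3 and the second group contains a dynamic size.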
1787 RankedTensorType
1788 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1789  ArrayRef<AffineMap> reassociation) {
1790  auto shape = type.getShape();
1791  SmallVector<int64_t, 4> newShape;
1792  newShape.reserve(reassociation.size());
1793 
1794  // Use the fact that reassociation is valid to simplify the logic: only use
1795  // each map's rank.
1796  assert(isReassociationValid(reassociation) && "invalid reassociation");
1797  unsigned currentDim = 0;
1798  for (AffineMap m : reassociation) {
1799  unsigned dim = m.getNumResults();
1800  auto band = shape.slice(currentDim, dim);
1801  int64_t size = 1;
1802  if (llvm::is_contained(band, ShapedType::kDynamic))
1803  size = ShapedType::kDynamic;
1804  else
1805  for (unsigned d = 0; d < dim; ++d)
1806  size *= shape[currentDim + d];
1807  newShape.push_back(size);
1808  currentDim += dim;
1809  }
1810 
1811  return RankedTensorType::get(newShape, type.getElementType());
1812 }
1813 
1814 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1815  ArrayRef<ReassociationIndices> reassociation,
1816  ArrayRef<NamedAttribute> attrs) {
1817  auto resultType = inferCollapsedType(
1818  llvm::cast<RankedTensorType>(src.getType()),
1819      getSymbolLessAffineMaps(
1820          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
1821  result.addAttribute(getReassociationAttrStrName(),
1822  getReassociationIndicesAttribute(b, reassociation));
1823  build(b, result, resultType, src, attrs);
1824 }
1825 
1826 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
1827  TensorReshapeOp, ExpandShapeOp>::value>
1828 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
1829  RankedTensorType expandedType,
1830  RankedTensorType collapsedType) {
1831  if (failed(
1832  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1833  return failure();
1834 
1835  auto maps = op.getReassociationMaps();
1836  RankedTensorType expectedType =
1837  CollapseShapeOp::inferCollapsedType(expandedType, maps);
1838  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
1839  return op.emitOpError("expected collapsed type to be ")
1840  << expectedType << ", but got " << collapsedType;
1841  return success();
1842 }
1843 
1844 LogicalResult ExpandShapeOp::verify() {
1845  auto srcType = getSrcType();
1846  auto resultType = getResultType();
1847 
1848  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1849  return emitOpError("expected number of static shape dims to be equal to "
1850  "the output rank (")
1851  << resultType.getRank() << ") but found "
1852  << getStaticOutputShape().size() << " inputs instead";
1853 
1854  if ((int64_t)getOutputShape().size() !=
1855  llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1856  return emitOpError("mismatch in dynamic dims in output_shape and "
1857  "static_output_shape: static_output_shape has ")
1858  << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1859  << " dynamic dims while output_shape has " << getOutputShape().size()
1860  << " values";
1861 
1862  return verifyTensorReshapeOp(*this, resultType, srcType);
1863 }
1864 
1865 LogicalResult CollapseShapeOp::verify() {
1866  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
1867 }
1868 
1869 namespace {
1870 /// Reshape of a splat constant can be replaced with a constant of the result
1871 /// type.
1872 template <typename TensorReshapeOp>
1873 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
1874  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1875  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1876  PatternRewriter &rewriter) const override {
1877  DenseElementsAttr attr;
1878  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
1879  return failure();
1880  if (!attr || !attr.isSplat())
1881  return failure();
1882    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1883        reshapeOp.getResultType(), attr.getRawData());
1884  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
1885  return success();
1886  }
1887 };
1888 
1889 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
1890 template <typename TensorReshapeOp>
1891 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
1892 public:
1893  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1894 
1895  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1896  PatternRewriter &rewriter) const override {
1897  auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1898  if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1899  return failure();
1900 
1901  rewriter.replaceOpWithNewOp<tensor::SplatOp>(
1902  reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1903  return success();
1904  }
1905 };
1906 
1907 /// Reshape of a FromElements can be replaced with a FromElements of the
1908 /// result type
1909 template <typename TensorReshapeOp>
1910 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
1911  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
1912  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1913  PatternRewriter &rewriter) const override {
1914  auto fromElements =
1915  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1916  if (!fromElements)
1917  return failure();
1918 
1919  auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1920 
1921  if (!shapedTy.hasStaticShape())
1922  return failure();
1923 
1924  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
1925  fromElements.getElements());
1926  return success();
1927  }
1928 };
1929 
1930 // Fold CastOp into CollapseShapeOp when adding static information.
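// Illustrative sketch: if %c = tensor.cast %a : tensor<2x4xf32> to
// tensor<?x4xf32> feeds tensor.collapse_shape %c [[0, 1]] : tensor<?x4xf32>
// into tensor<?xf32>, the collapse is rebuilt on %a with result type
// tensor<8xf32>, followed by a tensor.cast back to tensor<?xf32>.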
1931 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
1932  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1933 
1934  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1935  PatternRewriter &rewriter) const override {
1936  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1937  if (!tensor::canFoldIntoConsumerOp(castOp))
1938  return failure();
1939 
1940  RankedTensorType srcType =
1941  llvm::cast<RankedTensorType>(castOp.getSource().getType());
1942  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1943  srcType, collapseShapeOp.getReassociationMaps());
1944 
1945  if (newResultType == collapseShapeOp.getResultType()) {
1946  rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
1947  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1948  });
1949  } else {
1950  auto newOp = rewriter.create<CollapseShapeOp>(
1951  collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1952  collapseShapeOp.getReassociation());
1953  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1954  collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1955  }
1956  return success();
1957  }
1958 };
1959 
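/// Folds a tensor.dim of a dynamic result dimension of tensor.expand_shape
/// into an affine.apply on the corresponding source dimension. Illustrative
/// sketch (abbreviated syntax): for an expand_shape of tensor<?xf32> into
/// tensor<?x4xf32> with reassociation [[0, 1]], `tensor.dim %expanded, 0`
/// becomes `tensor.dim %src, 0` floordiv 4, because the dynamic result dim is
/// the source dim divided by the product of the group's static dims.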
1960 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
1961  using OpRewritePattern<DimOp>::OpRewritePattern;
1962 
1963  LogicalResult matchAndRewrite(DimOp dimOp,
1964  PatternRewriter &rewriter) const override {
1965  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1966  if (!expandShapeOp)
1967  return failure();
1968 
1969  // Only constant dimension values are supported.
1970  std::optional<int64_t> dim = dimOp.getConstantIndex();
1971  if (!dim.has_value())
1972  return failure();
1973 
1974  // Skip static dims. These are folded to constant ops.
1975  RankedTensorType resultType = expandShapeOp.getResultType();
1976  if (!resultType.isDynamicDim(*dim))
1977  return failure();
1978 
1979  // Find reassociation group that contains this result dimension.
1980  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1981 
1982  // `dim` is the only dynamic dimension in `group`. (Otherwise, the
1983  // ExpandShapeOp would be ambiguous.)
1984  int64_t product = 1;
1985  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
1986  for (int64_t d : grp) {
1987  if (d != dim) {
1988  assert(!resultType.isDynamicDim(d) && "expected static dim");
1989  product *= resultType.getDimSize(d);
1990  }
1991  }
1992 
1993  // result dim size = src dim size / (product(other dims in reassoc group))
1994  Value srcDimSz =
1995  rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1996  AffineExpr expr;
1997  bindSymbols(dimOp.getContext(), expr);
1998  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
1999  dimOp, expr.floorDiv(product), srcDimSz);
2000  return success();
2001  }
2002 };
2003 
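/// Folds a tensor.dim of a dynamic result dimension of tensor.collapse_shape
/// into an affine.apply that multiplies the sizes of the source dimensions in
/// the corresponding reassociation group. Illustrative sketch: for a collapse
/// of tensor<?x4xf32> into tensor<?xf32> with reassociation [[0, 1]],
/// `tensor.dim %collapsed, 0` becomes `tensor.dim %src, 0` * `tensor.dim %src, 1`.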
2004 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
2005  using OpRewritePattern<DimOp>::OpRewritePattern;
2006 
2007  LogicalResult matchAndRewrite(DimOp dimOp,
2008  PatternRewriter &rewriter) const override {
2009  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
2010  if (!collapseShapeOp)
2011  return failure();
2012 
2013  // Only constant dimension values are supported.
2014  std::optional<int64_t> dim = dimOp.getConstantIndex();
2015  if (!dim.has_value())
2016  return failure();
2017 
2018  // Skip static dims. These are folded to constant ops.
2019  RankedTensorType resultType = collapseShapeOp.getResultType();
2020  if (!resultType.isDynamicDim(*dim))
2021  return failure();
2022 
2023  // Get reassociation group of the result dimension.
2024  ReassociationIndices group =
2025  collapseShapeOp.getReassociationIndices()[*dim];
2026 
2027  // result dim size = product(dims in reassoc group)
2028  SmallVector<Value> srcDimSizes;
2029  SmallVector<AffineExpr> syms;
2030  AffineExpr product;
2031  for (const auto &it : llvm::enumerate(group)) {
2032  srcDimSizes.push_back(rewriter.create<DimOp>(
2033  dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
2034  syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
2035  product = product ? product * syms.back() : syms.back();
2036  }
2037  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
2038  srcDimSizes);
2039  return success();
2040  }
2041 };
2042 
2043 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2044 /// matching constant output_shape operands of the expand. This makes the
2045 /// `tensor.expand_shape` more static and creates a consumer cast that can be
2046 /// propagated further.
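///
/// Illustrative sketch (abbreviated syntax): with %c = tensor.cast %a :
/// tensor<8xf32> to tensor<?xf32> feeding an expand_shape to tensor<?x4xf32>
/// whose first output_shape operand is a constant 2, the pattern rebuilds the
/// expand_shape as tensor<8xf32> into tensor<2x4xf32> and inserts a
/// tensor.cast of the result back to tensor<?x4xf32>.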
2047 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2048  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2049 
2050  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2051  PatternRewriter &rewriter) const override {
2052  auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2053  if (!canFoldIntoConsumerOp(castOp))
2054  return failure();
2055 
2056  ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2057    SmallVector<ReassociationIndices, 4> reassoc =
2058        expandOp.getReassociationIndices();
2059 
2060  SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2061  SmallVector<Value> dynamicOutputShape;
2062  auto outputIt = expandOp.getOutputShape().begin();
2063 
2064  for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2065  for (uint64_t outDim : innerReassoc) {
2066  if (!ShapedType::isDynamic(newOutputShape[outDim]))
2067  continue;
2068 
2069  // If the cast's src type is dynamic, don't infer any of the
2070  // corresponding expanded dimensions. `tensor.expand_shape` requires at
2071  // least one of the expanded dimensions to be dynamic if the input is
2072  // dynamic.
2073  Value val = *outputIt;
2074  ++outputIt;
2075  if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2076  dynamicOutputShape.push_back(val);
2077  continue;
2078  }
2079 
2080  APInt cst;
2081  if (matchPattern(val, m_ConstantInt(&cst))) {
2082  newOutputShape[outDim] = cst.getSExtValue();
2083  } else {
2084  dynamicOutputShape.push_back(val);
2085  }
2086  }
2087  }
2088 
2089  // Couldn't match any values, nothing to change
2090  if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2091  return failure();
2092 
2093  // Calculate the input shape from the output
2094  SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2095  for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2096  for (auto outDim : reassoc[inDim]) {
2097  auto ofr = newOutputShape[outDim];
2098  if (ShapedType::isDynamic(ofr)) {
2099  newInputShape[inDim] = ShapedType::kDynamic;
2100  break;
2101  }
2102  newInputShape[inDim] *= ofr;
2103  }
2104  }
2105 
2106  SmallVector<OpFoldResult> outputOfr =
2107  getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2108  auto inputType = RankedTensorType::get(
2109  newInputShape, expandOp.getSrcType().getElementType());
2110  auto outputType = RankedTensorType::get(
2111  newOutputShape, expandOp.getSrcType().getElementType());
2112  auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
2113  expandOp.getSrc());
2114  auto newExpand = rewriter.create<ExpandShapeOp>(
2115  expandOp.getLoc(), outputType, inputCast.getResult(),
2116  expandOp.getReassociationIndices(), outputOfr);
2117  rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2118  newExpand.getResult());
2119  return success();
2120  }
2121 };
2122 } // namespace
2123 
2124 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2125  MLIRContext *context) {
2126  results.add<
2127      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2128      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2129      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2130  FoldReshapeWithSplat<ExpandShapeOp>,
2131  FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
2132  FoldDimOfCollapseShape>(context);
2133 }
2134 
2135 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2136  MLIRContext *context) {
2137  results.add<
2138      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2139      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2140  tensor::DimOp, RankedTensorType>,
2141  FoldReshapeWithConstant<CollapseShapeOp>,
2142  FoldReshapeWithSplat<CollapseShapeOp>,
2143  FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2144  context);
2145 }
2146 
2147 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2148  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2149  adaptor.getOperands());
2150 }
2151 
2152 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2153  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2154  adaptor.getOperands());
2155 }
2156 
2157 //===----------------------------------------------------------------------===//
2158 // ExtractSliceOp
2159 //===----------------------------------------------------------------------===//
2160 
2161 void ExtractSliceOp::getAsmResultNames(
2162  function_ref<void(Value, StringRef)> setNameFn) {
2163  setNameFn(getResult(), "extracted_slice");
2164 }
2165 
2166 /// An extract_slice result type can be inferred, when it is not
2167 /// rank-reduced, from the source type and the static representation of
2168 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2169 RankedTensorType ExtractSliceOp::inferResultType(
2170  RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2171  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2172  // An extract_slice op may specify only a leading subset of offset/sizes/
2173  // strides, in which case we complete with offset=0, sizes from the source
2174  // tensor type, and strides=1.
2175  assert(static_cast<int64_t>(staticSizes.size()) ==
2176  sourceTensorType.getRank() &&
2177  "unexpected staticSizes not equal to rank of source");
2178  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2179  sourceTensorType.getEncoding());
2180 }
2181 
2182 RankedTensorType ExtractSliceOp::inferResultType(
2183  RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2184      ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2185  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2186  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2187  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2188  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2189  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2190  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2191  staticSizes, staticStrides);
2192 }
2193 
2194 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2195 /// number of sizes), drop as many size-1 dimensions as needed to produce an
2196 /// inferred type with the desired rank.
2197 ///
2198 /// Note that there may be multiple ways to compute this rank-reduced type:
2199 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2200 ///
2201 /// To disambiguate, this function always drops the leading size-1 dimensions first.
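/// For example (illustrative), sizes 1x6x1 with desiredResultRank = 2 infer a
/// 6x1 result: the leading unit dimension is dropped and the trailing one is
/// kept to reach the requested rank.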
2202 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2203  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2204  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2205  ArrayRef<int64_t> strides) {
2206  // Type inferred in the absence of rank-reducing behavior.
2207  auto inferredType = llvm::cast<RankedTensorType>(
2208  inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2209  int rankDiff = inferredType.getRank() - desiredResultRank;
2210  if (rankDiff > 0) {
2211  auto shape = inferredType.getShape();
2212  llvm::SmallBitVector dimsToProject =
2213  getPositionsOfShapeOne(rankDiff, shape);
2214  SmallVector<int64_t> projectedShape;
2215  // Best effort rank-reducing: drop 1s in order.
2216  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2217  if (!dimsToProject.test(pos))
2218  projectedShape.push_back(shape[pos]);
2219  inferredType =
2220  RankedTensorType::get(projectedShape, inferredType.getElementType());
2221  }
2222  return inferredType;
2223 }
2224 
2225 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2226  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2227      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2228      ArrayRef<OpFoldResult> strides) {
2229  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2230  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2231  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2232  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2233  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2234  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2235  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2236  staticStrides);
2237 }
2238 
2239 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2240 /// result type. If the type passed is nullptr, it is inferred.
2241 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2242  RankedTensorType resultType, Value source,
2243  ArrayRef<OpFoldResult> offsets,
2244  ArrayRef<OpFoldResult> sizes,
2245  ArrayRef<OpFoldResult> strides,
2246  ArrayRef<NamedAttribute> attrs) {
2247  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2248  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2249  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2250  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2251  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2252  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2253  // Structuring implementation this way avoids duplication between builders.
2254  if (!resultType) {
2255  resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2256  sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2257  }
2258  result.addAttributes(attrs);
2259  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2260  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2261  b.getDenseI64ArrayAttr(staticSizes),
2262  b.getDenseI64ArrayAttr(staticStrides));
2263 }
2264 
2265 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2266 /// result type.
2267 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2268  ArrayRef<OpFoldResult> offsets,
2269  ArrayRef<OpFoldResult> sizes,
2270  ArrayRef<OpFoldResult> strides,
2271  ArrayRef<NamedAttribute> attrs) {
2272  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2273 }
2274 
2275 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2276 /// a Range vector.
2277 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2278  ArrayRef<Range> ranges,
2279  ArrayRef<NamedAttribute> attrs) {
2280  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2281  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2282 }
2283 
2284 /// Build an ExtractSliceOp with dynamic entries and custom result type. If
2285 /// the type passed is nullptr, it is inferred.
2286 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2287  RankedTensorType resultType, Value source,
2288  ValueRange offsets, ValueRange sizes,
2289  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2290  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2291  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2292  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2293  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2294  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2295  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2296  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2297 }
2298 
2299 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
2300 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2301  ValueRange offsets, ValueRange sizes,
2302  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2303  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2304 }
2305 
2306 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2307  Operation *op,
2308  RankedTensorType expectedType) {
2309  switch (result) {
2310  case SliceVerificationResult::Success:
2311    return success();
2312  case SliceVerificationResult::RankTooLarge:
2313    return op->emitError("expected rank to be smaller or equal to ")
2314           << "the other rank. ";
2315  case SliceVerificationResult::SizeMismatch:
2316    return op->emitError("expected type to be ")
2317           << expectedType << " or a rank-reduced version. (size mismatch) ";
2318  case SliceVerificationResult::ElemTypeMismatch:
2319    return op->emitError("expected element type to be ")
2320           << expectedType.getElementType();
2321  default:
2322  llvm_unreachable("unexpected extract_slice op verification result");
2323  }
2324 }
2325 
2326 /// Verifier for ExtractSliceOp.
2327 LogicalResult ExtractSliceOp::verify() {
2328  // Verify result type against inferred type.
2329  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2330  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2331  SliceVerificationResult result = isRankReducedType(expectedType, getType());
2332  return produceSliceErrorMsg(result, *this, expectedType);
2333 }
2334 
2335 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2336  return ::getDroppedDims(getType().getShape(), getMixedSizes());
2337 }
2338 
2339 FailureOr<Value>
2340 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2341  ArrayRef<int64_t> desiredShape) {
2342  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2343  assert(sourceTensorType && "not a ranked tensor type");
2344  auto sourceShape = sourceTensorType.getShape();
2345  if (sourceShape.equals(desiredShape))
2346  return value;
2347  auto maybeRankReductionMask =
2348  mlir::computeRankReductionMask(sourceShape, desiredShape);
2349  if (!maybeRankReductionMask)
2350  return failure();
2351  return createCanonicalRankReducingExtractSliceOp(
2352      b, loc, value,
2353  RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2354 }
2355 
2356 LogicalResult ExtractSliceOp::reifyResultShapes(
2357  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2358  reifiedReturnShapes.resize(1);
2359  reifiedReturnShapes[0].reserve(getType().getRank());
2360  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2361  llvm::SmallBitVector droppedDims = getDroppedDims();
2362  for (const auto &size : enumerate(mixedSizes)) {
2363  if (droppedDims.test(size.index()))
2364  continue;
2365  reifiedReturnShapes[0].push_back(size.value());
2366  }
2367  return success();
2368 }
2369 
2370 namespace {
2371 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2372 /// This essentially pushes the tensor.cast past its consuming slice when
2373 /// `canFoldIntoConsumerOp` is true.
2374 ///
2375 /// Example:
2376 /// ```
2377 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2378 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2379 /// tensor<3x4xf32>
2380 /// ```
2381 /// is rewritten into:
2382 /// ```
2383 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2384 /// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
2385 /// ```
2386 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2387 public:
2388  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2389 
2390  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2391  PatternRewriter &rewriter) const override {
2392  // Any constant operand, just return to let the constant folder kick in.
2393  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2394  return matchPattern(operand, matchConstantIndex());
2395  }))
2396  return failure();
2397 
2398  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2399  if (!castOp)
2400  return failure();
2401 
2402  if (!canFoldIntoConsumerOp(castOp))
2403  return failure();
2404 
2405  // Create folded extract.
2406  Location loc = sliceOp.getLoc();
2407  Value newResult = rewriter.create<ExtractSliceOp>(
2408  loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2409  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2410  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2411  if (newResult.getType() != sliceOp.getType())
2412  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
2413  rewriter.replaceOp(sliceOp, newResult);
2414  return success();
2415  }
2416 };
2417 
2418 /// Slice elements from `values` into `outValues`. `counts` gives, per
2419 /// dimension, the linearized stride (number of elements) in the original values.
2420 /// The output values can be used to construct a DenseElementsAttr.
2421 template <typename IterTy, typename ElemTy>
2422 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2423  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2424  ArrayRef<int64_t> strides,
2425  llvm::SmallVectorImpl<ElemTy> *outValues) {
2426  assert(offsets.size() == sizes.size());
2427  assert(offsets.size() == strides.size());
2428  if (offsets.empty())
2429  return;
2430 
2431  int64_t offset = offsets.front();
2432  int64_t size = sizes.front();
2433  int64_t stride = strides.front();
2434  if (offsets.size() == 1) {
2435  for (int64_t i = 0; i < size; ++i, offset += stride)
2436  outValues->push_back(*(values + offset));
2437 
2438  return;
2439  }
2440 
2441  for (int64_t i = 0; i < size; ++i, offset += stride) {
2442  auto begin = values + offset * counts.front();
2443  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2444  offsets.drop_front(), sizes.drop_front(),
2445  strides.drop_front(), outValues);
2446  }
2447 }
2448 
2449 /// Fold arith.constant and tensor.extract_slice into arith.constant. The
2450 /// folded operation might introduce more constant data; Users can control
2451 /// their heuristics by the control function.
2452 class ConstantOpExtractSliceFolder final
2453  : public OpRewritePattern<ExtractSliceOp> {
2454 public:
2455  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2456 
2457  ConstantOpExtractSliceFolder(MLIRContext *context,
2458                               ControlConstantExtractSliceFusionFn controlFn)
2459      : OpRewritePattern<ExtractSliceOp>(context),
2460  controlFn(std::move(controlFn)) {}
2461 
2462  LogicalResult matchAndRewrite(ExtractSliceOp op,
2463  PatternRewriter &rewriter) const override {
2464  DenseElementsAttr attr;
2465  if (!matchPattern(op.getSource(), m_Constant(&attr)))
2466  return failure();
2467 
2468  // A constant splat is handled by fold().
2469  if (attr.isSplat())
2470  return failure();
2471 
2472  // Dynamic result shape is not supported.
2473  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2474  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2475  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2476  return failure();
2477 
2478  // Customized control over the folding.
2479  if (!controlFn(op))
2480  return failure();
2481 
2482  int64_t count = sourceType.getNumElements();
2483  if (count == 0)
2484  return failure();
2485 
2486  // Check if there are any dynamic parts, which are not supported.
2487  auto offsets = op.getStaticOffsets();
2488  if (llvm::is_contained(offsets, ShapedType::kDynamic))
2489  return failure();
2490  auto sizes = op.getStaticSizes();
2491  if (llvm::is_contained(sizes, ShapedType::kDynamic))
2492  return failure();
2493  auto strides = op.getStaticStrides();
2494  if (llvm::is_contained(strides, ShapedType::kDynamic))
2495  return failure();
2496 
2497  // Compute the stride for each dimension.
2498  SmallVector<int64_t> counts;
2499  ArrayRef<int64_t> shape = sourceType.getShape();
2500  counts.reserve(shape.size());
2501  for (int64_t v : shape) {
2502  count = count / v;
2503  counts.push_back(count);
2504  }
2505 
2506  // New attribute constructed by the sliced values.
2507  DenseElementsAttr newAttr;
2508 
2509  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2510  SmallVector<APInt> outValues;
2511  outValues.reserve(sourceType.getNumElements());
2512  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2513  elems.begin(), counts, offsets, sizes, strides, &outValues);
2514  newAttr = DenseElementsAttr::get(resultType, outValues);
2515  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2516  SmallVector<APFloat> outValues;
2517  outValues.reserve(sourceType.getNumElements());
2518  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2519  elems.begin(), counts, offsets, sizes, strides, &outValues);
2520  newAttr = DenseElementsAttr::get(resultType, outValues);
2521  }
2522 
2523  if (newAttr) {
2524  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2525  return success();
2526  }
2527 
2528  return failure();
2529  }
2530 
2531 private:
2532  /// This additionally controls whether the fold happens or not. Users can
2533  /// impose their heuristics in the function.
2534  ControlConstantExtractSliceFusionFn controlFn;
2535 };
2536 
2537 } // namespace
2538 
2539 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2540     RewritePatternSet &patterns,
2541  const ControlConstantExtractSliceFusionFn &controlFn) {
2542  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2543 }
2544 
2545 /// Return the canonical type of the result of an extract_slice op.
2546 struct SliceReturnTypeCanonicalizer {
2547  RankedTensorType operator()(ExtractSliceOp op,
2548  ArrayRef<OpFoldResult> mixedOffsets,
2549  ArrayRef<OpFoldResult> mixedSizes,
2550  ArrayRef<OpFoldResult> mixedStrides) {
2551  return ExtractSliceOp::inferCanonicalRankReducedResultType(
2552  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2553  mixedStrides);
2554  }
2555 };
2556 
2557 /// A canonicalizer wrapper to replace ExtractSliceOps.
2558 struct SliceCanonicalizer {
2559  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2560  ExtractSliceOp newOp) {
2561  Value replacement = newOp.getResult();
2562  if (replacement.getType() != op.getType())
2563  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2564  replacement);
2565  rewriter.replaceOp(op, replacement);
2566  }
2567 };
2568 
2569 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2570  MLIRContext *context) {
2571  results.add<
2572      OpWithOffsetSizesAndStridesConstantArgumentFolder<
2573          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2574      ExtractSliceOpCastFolder>(context);
2575 }
2576 
2577 /// Detects identity slices: all offsets are 0, all strides are 1, and the sizes match the leading dimensions of `shapedType`.
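/// Illustrative example: tensor.extract_slice %t[0, 0] [4, 8] [1, 1]
/// : tensor<4x8xf32> to tensor<4x8xf32> is such an identity slice; any
/// non-zero offset, non-unit stride, or smaller size prevents the fold.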
2578 static LogicalResult
2579 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2580  ShapedType shapedType) {
2581  OpBuilder b(op.getContext());
2582  for (OpFoldResult ofr : op.getMixedOffsets())
2583  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2584  return failure();
2585  // Rank-reducing noops only need to inspect the leading dimensions:
2586  // llvm::zip is appropriate.
2587  auto shape = shapedType.getShape();
2588  for (auto it : llvm::zip(op.getMixedSizes(), shape))
2589  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2590  return failure();
2591  for (OpFoldResult ofr : op.getMixedStrides())
2592  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2593  return failure();
2594  return success();
2595 }
2596 
2597 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2598 /// slice, we can return the InsertSliceOp's source directly.
2599 // TODO: This only checks the immediate producer; extend to go up the
2600 // insert/extract chain if the slices are disjoint.
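/// Example (illustrative):
/// ```mlir
/// %0 = tensor.insert_slice %src into %dest[0, 1] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// %1 = tensor.extract_slice %0[0, 1] [4, 4] [1, 1]
///     : tensor<8x8xf32> to tensor<4x4xf32>
/// ```
/// folds %1 to %src.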
2601 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2602  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2603 
2604  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2605  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2606  insertOp.isSameAs(extractOp, isSame))
2607  return insertOp.getSource();
2608 
2609  return {};
2610 }
2611 
2612 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2613  if (OpFoldResult reshapedSource = reshapeConstantSource(
2614  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2615  getResult().getType()))
2616  return reshapedSource;
2617  if (getSourceType() == getType() &&
2618      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2619    return this->getSource();
2620  if (Value slice = foldExtractAfterInsertSlice(*this))
2621  return slice;
2622 
2623  return OpFoldResult();
2624 }
2625 
2626 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2627     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2628  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2629  unsigned rank = rankedTensorType.getRank();
2630  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
2631  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2632  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2633  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2634  offsets, sizes, strides);
2635 }
2636 
2637 //===----------------------------------------------------------------------===//
2638 // InsertSliceOp
2639 //===----------------------------------------------------------------------===//
2640 
2641 void InsertSliceOp::getAsmResultNames(
2642  function_ref<void(Value, StringRef)> setNameFn) {
2643  setNameFn(getResult(), "inserted_slice");
2644 }
2645 
2646 // Build an InsertSliceOp with mixed static and dynamic entries.
2647 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2648  Value dest, ArrayRef<OpFoldResult> offsets,
2649  ArrayRef<OpFoldResult> sizes,
2650  ArrayRef<OpFoldResult> strides,
2651  ArrayRef<NamedAttribute> attrs) {
2652  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2653  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2654  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2655  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2656  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2657  result.addAttributes(attrs);
2658  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2659  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2660  b.getDenseI64ArrayAttr(staticSizes),
2661  b.getDenseI64ArrayAttr(staticStrides));
2662 }
2663 
2664 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2665 /// Range vector.
2666 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2667  Value dest, ArrayRef<Range> ranges,
2668  ArrayRef<NamedAttribute> attrs) {
2669  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2670  build(b, result, source, dest, offsets, sizes, strides, attrs);
2671 }
2672 
2673 // Build an InsertSliceOp with dynamic entries.
2674 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2675  Value dest, ValueRange offsets, ValueRange sizes,
2676  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2677  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2678  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2679  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2680  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2681  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2682  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2683  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2684 }
2685 
2686 /// Rank-reducing type verification for both InsertSliceOp and
2687 /// ParallelInsertSliceOp.
2688 static SliceVerificationResult verifyInsertSliceOp(
2689     RankedTensorType srcType, RankedTensorType dstType,
2690  ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2691  ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2692  // insert_slice is the inverse of extract_slice, use the same type
2693  // inference.
2694  RankedTensorType expected = ExtractSliceOp::inferResultType(
2695  dstType, staticOffsets, staticSizes, staticStrides);
2696  if (expectedType)
2697  *expectedType = expected;
2698  return isRankReducedType(expected, srcType);
2699 }
2700 
2701 /// Verifier for InsertSliceOp.
2702 LogicalResult InsertSliceOp::verify() {
2703  RankedTensorType expectedType;
2704  SliceVerificationResult result =
2705  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2706  getStaticSizes(), getStaticStrides(), &expectedType);
2707  return produceSliceErrorMsg(result, *this, expectedType);
2708 }
2709 
2710 /// If we have two consecutive InsertSliceOp writing to the same slice, we
2711 /// can mutate the second InsertSliceOp's destination to the first one's.
2712 ///
2713 /// Example:
2714 ///
2715 /// ```mlir
2716 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2717 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2718 /// ```
2719 ///
2720 /// folds into:
2721 ///
2722 /// ```mlir
2723 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2724 /// ```
2725 ///
2726 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2727 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2728  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2729 
2730  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2731  if (!prevInsertOp ||
2732  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2733  !prevInsertOp.isSameAs(insertOp, isSame))
2734  return failure();
2735 
2736  insertOp.getDestMutable().assign(prevInsertOp.getDest());
2737  return success();
2738 }
2739 
2740 /// Folds round-trip extract/insert slice op pairs.
2741 /// Example:
2742 /// ```mlir
2743 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2744 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2745 /// ```
2746 /// can be folded into %val.
2747 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2748  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2749 
2750  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2751  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2752  !extractOp.isSameAs(insertOp, isSame))
2753  return nullptr;
2754 
2755  return extractOp.getSource();
2756 }
2757 
2758 OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2759  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2760  getSourceType() == getType() &&
2761      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2762    return this->getSource();
2763  if (succeeded(foldInsertAfterInsertSlice(*this)))
2764  return getResult();
2765  if (auto result = foldInsertAfterExtractSlice(*this))
2766  return result;
2767  if (llvm::any_of(getMixedSizes(),
2768  [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2769  return getDest();
2770  return OpFoldResult();
2771 }
2772 
2773 LogicalResult InsertSliceOp::reifyResultShapes(
2774  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2775  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2776  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2777  return success();
2778 }
2779 
2780 namespace {
2781 /// Pattern to rewrite an insert_slice op with constant arguments.
2782 ///
2783 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
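///
/// Illustrative sketch: an insert_slice whose size operands are constant 4
/// values is rewritten to use the static sizes [4, 4]; if that makes the
/// inferred source type more static (e.g. tensor<4x4xf32> instead of
/// tensor<?x?xf32>), a tensor.cast of the source is inserted first.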
2784 template <typename InsertOpTy>
2785 class InsertSliceOpConstantArgumentFolder final
2786  : public OpRewritePattern<InsertOpTy> {
2787 public:
2788  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2789 
2790  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2791  PatternRewriter &rewriter) const override {
2792  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2793  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2794  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2795 
2796    // No constant operands were folded, just return.
2797  if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2798  failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2799  failed(foldDynamicStrideList(mixedStrides)))
2800  return failure();
2801 
2802  // Create the new op in canonical form.
2803  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2804  insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2805  mixedOffsets, mixedSizes, mixedStrides);
2806  Value toInsert = insertSliceOp.getSource();
2807  if (sourceType != insertSliceOp.getSourceType()) {
2808  OpBuilder::InsertionGuard g(rewriter);
2809  // The only difference between InsertSliceOp and ParallelInsertSliceOp
2810  // is that the insertion point is just before the ParallelCombiningOp in
2811  // the parallel case.
2812  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2813  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2814  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2815  sourceType, toInsert);
2816  }
2817  rewriter.replaceOpWithNewOp<InsertOpTy>(
2818  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2819  mixedSizes, mixedStrides);
2820  return success();
2821  }
2822 };
2823 
2824 /// Fold tensor_casts with insert_slice operations. If the source or
2825 /// destination tensor is a tensor_cast that removes static type information,
2826 /// the cast is folded into the insert_slice operation. E.g.:
2827 ///
2828 /// ```mlir
2829 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
2830 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
2831 /// ```
2832 ///
2833 /// folds into:
2834 ///
2835 /// ```mlir
2836 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
2837 /// ```
2838 ///
2839 /// Note: When folding a cast on the destination tensor, the result of the
2840 /// insert_slice operation is cast to ensure that the type of the result does
2841 /// not change.
2842 ///
2843 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2844 template <typename InsertOpTy>
2845 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2846  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2847 
2848  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2849  PatternRewriter &rewriter) const override {
2850  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
2851  return matchPattern(operand, matchConstantIndex());
2852  }))
2853  return failure();
2854 
2855  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
2856  auto castOp = v.getDefiningOp<tensor::CastOp>();
2857  if (!castOp || !canFoldIntoConsumerOp(castOp))
2858  return std::nullopt;
2859  return castOp.getSource();
2860  };
2861  std::optional<Value> sourceCastSource =
2862  getSourceOfCastOp(insertSliceOp.getSource());
2863  std::optional<Value> destCastSource =
2864  getSourceOfCastOp(insertSliceOp.getDest());
2865  if (!sourceCastSource && !destCastSource)
2866  return failure();
2867 
2868  auto src =
2869  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2870  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2871  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
2872  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2873  if (!srcType || !dstType)
2874  return failure();
2875 
2876  // The tensor.cast source could have additional static information not seen
2877  // in the insert slice op static sizes, so we ignore dynamic dims when
2878  // computing the rank reduction mask.
2879  SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
2880  auto rankReductionMask = computeRankReductionMask(
2881  staticSizes, srcType.getShape(), /*matchDynamic=*/true);
2882  if (!rankReductionMask.has_value())
2883  return failure();
2884  // Replace dimensions in the insert slice op with corresponding static dims
2885  // from the cast source type. If the insert slice sizes have static dims
2886  // that are not static in the tensor.cast source (i.e., when the cast op
2887  // casts a dynamic dim to static), the dim should not be replaced, and the
2888  // pattern will fail later in `verifyInsertSliceOp`.
2889  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2890  int64_t rankReducedIdx = 0;
2891  for (auto [idx, size] : enumerate(staticSizes)) {
2892  if (!rankReductionMask.value().contains(idx) &&
2893  !srcType.isDynamicDim(rankReducedIdx)) {
2894  mixedSizes[idx] = getAsIndexOpFoldResult(
2895  rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
2896  size = srcType.getDimSize(rankReducedIdx++);
2897  }
2898  }
2899  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
2900  staticSizes, insertSliceOp.getStaticStrides()) !=
2901          SliceVerificationResult::Success)
2902      return failure();
2903 
2904  Operation *replacement = rewriter.create<InsertOpTy>(
2905  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2906  mixedSizes, insertSliceOp.getMixedStrides());
2907 
2908  // In the parallel case there is no result and so nothing to cast.
2909  bool isParallelInsert =
2910  std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2911  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2912  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2913  insertSliceOp.getDestType(),
2914  replacement->getResult(0));
2915  }
2916  rewriter.replaceOp(insertSliceOp, replacement->getResults());
2917  return success();
2918  }
2919 };
2920 
2921 /// If additional static type information can be deduced from an insert_slice's
2922 /// size operands, insert an explicit cast of the op's source operand. This
2923 /// enables other canonicalization patterns that are matching for tensor_cast
2924 /// ops such as `ForOpTensorCastFolder` in SCF.
2925 ///
2926 /// Example:
2927 ///
2928 /// ```mlir
2929 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2930 /// : tensor<?x?xf32> into ...
2931 /// ```
2932 ///
2933 /// folds into:
2934 ///
2935 /// ```mlir
2936 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2937 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2938 /// : tensor<64x64xf32> into ...
2939 /// ```
2940 ///
2941 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2942 template <typename InsertOpTy>
2943 struct InsertSliceOpSourceCastInserter final
2944  : public OpRewritePattern<InsertOpTy> {
2945  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2946 
2947  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2948  PatternRewriter &rewriter) const override {
2949  RankedTensorType srcType = insertSliceOp.getSourceType();
2950  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2951  return failure();
2952  SmallVector<int64_t> newSrcShape(srcType.getShape());
2953  for (int64_t i = 0; i < srcType.getRank(); ++i) {
2954  if (std::optional<int64_t> constInt =
2955  getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2956  // Bail on invalid IR.
2957  if (*constInt < 0)
2958  return failure();
2959  newSrcShape[i] = *constInt;
2960  }
2961  }
2962  if (!hasValidSizesOffsets(newSrcShape))
2963  return failure();
2964 
2965  RankedTensorType newSrcType = RankedTensorType::get(
2966  newSrcShape, srcType.getElementType(), srcType.getEncoding());
2967  if (srcType == newSrcType ||
2968  !preservesStaticInformation(srcType, newSrcType) ||
2969  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2970  return failure();
2971 
2972  // newSrcType is:
2973  // 1) Different from srcType.
2974  // 2) "More static" than srcType.
2975  // 3) Cast-compatible with srcType.
2976  // Insert the cast.
2977  OpBuilder::InsertionGuard g(rewriter);
2978  // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2979  // that the insertion point is just before the ParallelCombiningOp in the
2980  // parallel case.
2981  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2982  rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2983  Value cast = rewriter.create<tensor::CastOp>(
2984  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2985  rewriter.replaceOpWithNewOp<InsertOpTy>(
2986  insertSliceOp, cast, insertSliceOp.getDest(),
2987  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2988  insertSliceOp.getMixedStrides());
2989  return success();
2990  }
2991 };
2992 } // namespace
2993 
2994 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
2995  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
2996 }
2997 
2998 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2999  MLIRContext *context) {
3000  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3001  InsertSliceOpCastFolder<InsertSliceOp>,
3002  InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3003 }
3004 
3005 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3006  Location loc,
3007  Value tensor,
3008  Value dest) {
3009  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3010  unsigned rank = rankedTensorType.getRank();
3011  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3012  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3013  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3014  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3015  sizes, strides);
3016 }
3017 
3018 //===----------------------------------------------------------------------===//
3019 // PadOp
3020 //===----------------------------------------------------------------------===//
3021 
3022 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3023  setNameFn(getResult(), "padded");
3024 }
3025 
3026 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3027 // supports optional types.
3028 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3029  Type typeToInfer, Type typeToInferFrom) {}
3030 
3031 ParseResult
3032 parseInferType(OpAsmParser &parser,
3033  std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3034  Type &typeToInfer, Type typeToInferFrom) {
3035  if (optOperand)
3036  typeToInfer = typeToInferFrom;
3037  return success();
3038 }
3039 
3040 LogicalResult PadOp::verify() {
3041  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3042  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3043  auto expectedType =
3044  PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3045  if (!expectedType) {
3046  return emitError("failed to infer expectedType from sourceType ")
3047  << sourceType << ", specified resultType is " << resultType;
3048  }
3049  if (resultType.getRank() != expectedType.getRank()) {
3050  return emitError("specified type ")
3051  << resultType << " does not match the inferred type "
3052  << expectedType;
3053  }
3054  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3055  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3056  continue;
3057  if (expectedType.isDynamicDim(i))
3058  continue;
3059  return emitError("specified type ")
3060  << resultType << " does not match the inferred type "
3061  << expectedType;
3062  }
3063 
3064  return success();
3065 }
3066 
3067 LogicalResult PadOp::verifyRegions() {
3068  auto &region = getRegion();
3069  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3070  Block &block = region.front();
3071  if (block.getNumArguments() != rank)
3072  return emitError("expected the block to have ") << rank << " arguments";
3073 
3074  // Note: the number and type of yield values are checked in the YieldOp.
3075  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3076  if (!en.value().isIndex())
3077  return emitOpError("expected block argument ")
3078  << (en.index() + 1) << " to be an index";
3079  }
3080 
3081  // Ensure that the region yields an element of the right type.
3082  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3083  if (yieldOp.getValue().getType() !=
3084  llvm::cast<ShapedType>(getType()).getElementType())
3085  return emitOpError("expected yield type to match shape element type");
3086 
3087  return success();
3088 }
3089 
3090 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3091  ArrayRef<int64_t> staticLow,
3092  ArrayRef<int64_t> staticHigh,
3093  ArrayRef<int64_t> resultShape) {
3094  unsigned rank = sourceType.getRank();
3095  if (staticLow.size() != rank)
3096  return RankedTensorType();
3097  if (staticHigh.size() != rank)
3098  return RankedTensorType();
3099  if (!resultShape.empty() && resultShape.size() != rank)
3100  return RankedTensorType();
3101 
3102  SmallVector<int64_t, 4> inferredShape;
3103  for (auto i : llvm::seq<unsigned>(0, rank)) {
3104  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3105  staticHigh[i] == ShapedType::kDynamic) {
3106  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3107  : resultShape[i]);
3108  } else {
3109  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3110  assert((resultShape.empty() || size == resultShape[i] ||
3111  resultShape[i] == ShapedType::kDynamic) &&
3112  "mismatch between inferred shape and result shape");
3113  inferredShape.push_back(size);
3114  }
3115  }
3116 
3117  return RankedTensorType::get(inferredShape, sourceType.getElementType());
3118 }
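// A worked sketch of the inference above (hypothetical shapes, not taken from
// any test): a tensor<4x?xf32> source with staticLow = [1, 2] and
// staticHigh = [2, 3] infers tensor<7x?xf32>; dimension 0 becomes
// 4 + 1 + 2 = 7 and dimension 1 stays dynamic because the source is dynamic.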
3119 
3120 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3121  Value source, ArrayRef<int64_t> staticLow,
3122  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3123  bool nofold, ArrayRef<NamedAttribute> attrs) {
3124  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3125  if (!resultType)
3126  resultType = inferResultType(sourceType, staticLow, staticHigh);
3127  result.addAttributes(attrs);
3128  build(b, result, resultType, source, low, high,
3129  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3130  nofold ? b.getUnitAttr() : UnitAttr());
3131 }
3132 
3133 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3134  Value source, ValueRange low, ValueRange high, bool nofold,
3135  ArrayRef<NamedAttribute> attrs) {
3136  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3137  unsigned rank = sourceType.getRank();
3138  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3139  build(b, result, resultType, source, staticVector, staticVector, low, high,
3140  nofold, attrs);
3141 }
3142 
3143 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3144  Value source, ArrayRef<OpFoldResult> low,
3145  ArrayRef<OpFoldResult> high, bool nofold,
3146  ArrayRef<NamedAttribute> attrs) {
3147  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3148  SmallVector<Value, 4> dynamicLow, dynamicHigh;
3149  SmallVector<int64_t, 4> staticLow, staticHigh;
3150  // staticLow and staticHigh carry the full padding configuration: each entry
3151  // grows staticLow and staticHigh by one value. If the entry is dynamic
3152  // (i.e. not a constant), dynamicLow and dynamicHigh grow by one value as
3153  // well.
3154  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3155  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3156  if (!resultType) {
3157  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3158  }
3159  assert(llvm::isa<RankedTensorType>(resultType));
3160  result.addAttributes(attrs);
3161  build(b, result, resultType, source, dynamicLow, dynamicHigh,
3162  b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3163  nofold ? b.getUnitAttr() : UnitAttr());
3164 }
3165 
3166 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3167  Value source, ArrayRef<OpFoldResult> low,
3168  ArrayRef<OpFoldResult> high, Value constantPadValue,
3169  bool nofold, ArrayRef<NamedAttribute> attrs) {
3170  build(b, result, resultType, source, low, high, nofold, attrs);
3171 
3172  // Add a region and a block to yield the pad value.
3173  Region *region = result.regions[0].get();
3174  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3175  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3176  SmallVector<Location> blockArgLocs(sourceRank, result.location);
3177 
3178  // `b.createBlock` changes the insertion point within the block. Create a
3179  // guard that resets the builder's insertion point once the guard is destroyed.
3180  OpBuilder::InsertionGuard guard(b);
3181  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3182  b.create<tensor::YieldOp>(result.location, constantPadValue);
3183 }
3184 
3185 llvm::SmallBitVector PadOp::getPaddedDims() {
3186  llvm::SmallBitVector paddedDims(getSourceType().getRank());
3187  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3188  for (const auto &en : enumerate(paddingWidths))
3189  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3190  paddedDims.set(en.index());
3191  };
3192  extractPaddedDims(getMixedLowPad());
3193  extractPaddedDims(getMixedHighPad());
3194  return paddedDims;
3195 }
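// For illustration (hypothetical values): with low = [0, 2] and high = [1, 0],
// both dimensions are reported as padded, dim 0 via its high pad and dim 1 via
// its low pad.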
3196 
3197 namespace {
3198 // Folds tensor.pad when the padding is all static zeros and the `nofold`
3199 // attribute doesn't request otherwise.
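// A hedged example of the rewrite (hypothetical types): with zero low/high
// padding and no `nofold`,
//   %p = tensor.pad %t low[0, 0] high[0, 0] { ... } : tensor<?x4xf32> to tensor<8x4xf32>
// simply becomes
//   %c = tensor.cast %t : tensor<?x4xf32> to tensor<8x4xf32>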
3200 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3201  using OpRewritePattern<PadOp>::OpRewritePattern;
3202 
3203  LogicalResult matchAndRewrite(PadOp padTensorOp,
3204  PatternRewriter &rewriter) const override {
3205  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3206  return failure();
3207  if (padTensorOp.getNofold())
3208  return failure();
3209  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3210  padTensorOp, padTensorOp.getResult().getType(),
3211  padTensorOp.getSource());
3212  return success();
3213  }
3214 };
3215 
3216 // Fold CastOp into PadOp when adding static information.
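// A sketch with hypothetical types: if the source comes from a cast that drops
// static information,
//   %c = tensor.cast %src : tensor<8x16xf32> to tensor<?x?xf32>
//   %p = tensor.pad %c low[0, 0] high[2, 3] { ... } : tensor<?x?xf32> to tensor<?x?xf32>
// the pattern pads %src directly (yielding tensor<10x19xf32>) and, when the
// result type changed, casts back to the original tensor<?x?xf32> type.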
3217 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3218  using OpRewritePattern<PadOp>::OpRewritePattern;
3219 
3220  LogicalResult matchAndRewrite(PadOp padTensorOp,
3221  PatternRewriter &rewriter) const override {
3222  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3223  if (!tensor::canFoldIntoConsumerOp(castOp))
3224  return failure();
3225 
3226  auto newResultType = PadOp::inferResultType(
3227  llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3228  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3229  padTensorOp.getResultType().getShape());
3230 
3231  if (newResultType == padTensorOp.getResultType()) {
3232  rewriter.modifyOpInPlace(padTensorOp, [&]() {
3233  padTensorOp.getSourceMutable().assign(castOp.getSource());
3234  });
3235  } else {
3236  auto newOp = rewriter.create<PadOp>(
3237  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3238  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3239  padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3240  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3241  IRMapping mapper;
3242  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3243 
3244  rewriter.replaceOpWithNewOp<tensor::CastOp>(
3245  padTensorOp, padTensorOp.getResultType(), newOp);
3246  }
3247  return success();
3248  }
3249 };
3250 
3251 // Fold CastOp using the result of PadOp back into the latter if it adds
3252 // static information.
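// A sketch with hypothetical types: a single-use
//   %p = tensor.pad %src ... : tensor<?x?xf32> to tensor<?x?xf32>
//   %c = tensor.cast %p : tensor<?x?xf32> to tensor<10x19xf32>
// is rewritten into one tensor.pad that directly produces tensor<10x19xf32>.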
3253 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3254  using OpRewritePattern<PadOp>::OpRewritePattern;
3255 
3256  LogicalResult matchAndRewrite(PadOp padTensorOp,
3257  PatternRewriter &rewriter) const override {
3258  if (!padTensorOp.getResult().hasOneUse())
3259  return failure();
3260  auto tensorCastOp =
3261  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3262  if (!tensorCastOp)
3263  return failure();
3264  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3265  tensorCastOp.getDest().getType()))
3266  return failure();
3267 
3268  auto replacementOp = rewriter.create<PadOp>(
3269  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3270  padTensorOp.getSource(), padTensorOp.getStaticLow(),
3271  padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3272  padTensorOp.getHigh(), padTensorOp.getNofold(),
3273  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3274  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3275 
3276  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3277  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3278  return success();
3279  }
3280 };
3281 
3282 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3283 /// different dimensions. The pattern applies if the following preconditions
3284 /// hold:
3285 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
3286 /// 2) the tensor::ExtractSliceOps have only unit-strides,
3287 /// 3) the tensor::PadOps perform only high-padding,
3288 /// 4) the tensor::PadOps have the same constant padding value,
3289 /// 5) the tensor::PadOps do not have common padding dimensions,
3290 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3291 /// zero-offset for every dimension.
3292 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
3293 /// padded source dimensions.
3295 ///
3296 /// Example:
3297 ///
3298 /// ```mlir
3299 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3300 /// : tensor<64x64xf32> to tensor<?x64xf32>
3301 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3302 /// } : tensor<?x64xf32> to tensor<8x64xf32>
3303 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3304 /// : tensor<8x64xf32> to tensor<8x?xf32>
3305 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3306 /// } : tensor<8x?xf32> to tensor<8x4xf32>
3307 /// ```
3308 ///
3309 /// folds into:
3310 ///
3311 /// ```mlir
3312 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3313 /// : tensor<64x64xf32> to tensor<?x?xf32>
3314 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3315 /// } : tensor<?x?xf32> to tensor<8x4xf32>
3316 /// ```
3317 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3318  using OpRewritePattern<PadOp>::OpRewritePattern;
3319 
3320  LogicalResult matchAndRewrite(PadOp padOp,
3321  PatternRewriter &rewriter) const override {
3322  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3323  if (!innerSliceOp)
3324  return failure();
3325  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3326  if (!outerPadOp || outerPadOp.getNofold())
3327  return failure();
3328  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3329  if (!outerSliceOp)
3330  return failure();
3331 
3332  // 1) Fail if the chain is rank-reducing.
3333  int64_t rank = padOp.getSourceType().getRank();
3334  if (outerSliceOp.getSourceType().getRank() != rank) {
3335  return rewriter.notifyMatchFailure(padOp,
3336  "cannot fold rank-reducing chain");
3337  }
3338 
3339  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3340  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3341  return rewriter.notifyMatchFailure(
3342  padOp, "cannot fold non-unit stride ExtractSliceOps");
3343  }
3344 
3345  // 3) Fail if the tensor::PadOps have non-zero low padding.
3346  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3347  return rewriter.notifyMatchFailure(padOp,
3348  "cannot fold PadOps with low padding");
3349  }
3350 
3351  // 4) Fail if the tensor::PadOps padding values do not match.
3352  Attribute innerAttr, outerAttr;
3353  Value innerValue = padOp.getConstantPaddingValue();
3354  Value outerValue = outerPadOp.getConstantPaddingValue();
3355  if (!innerValue || !outerValue ||
3356  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3357  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3358  innerAttr != outerAttr) {
3359  return rewriter.notifyMatchFailure(
3360  padOp, "cannot fold PadOps with different padding values");
3361  }
3362 
3363  // 5) Fail if a dimension is padded by both tensor::PadOps.
3364  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3365  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3366  if (innerDims.anyCommon(outerDims)) {
3367  return rewriter.notifyMatchFailure(
3368  padOp, "cannot fold PadOps with common padding dimensions");
3369  }
3370 
3371  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3372  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3373  // for every dimension, and use the offset of the other pair. Fail if no
3374  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3375  // exists.
3376  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3377  for (auto en : enumerate(newOffsets)) {
3378  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3379  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3380  if (!innerDims.test(en.index()) &&
3381  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3382  en.value() = outerOffset;
3383  continue;
3384  }
3385  if (!outerDims.test(en.index()) &&
3386  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3387  en.value() = innerOffset;
3388  continue;
3389  }
3390  return rewriter.notifyMatchFailure(
3391  padOp, "cannot find zero-offset and zero-padding pair");
3392  }
3393 
3394  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3395  // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3396  // outer tensor::PadOp and fail if the size of the inner
3397  // tensor::ExtractSliceOp does not match the size of the padded dimension.
3398  // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3399  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3400  for (auto en : enumerate(newSizes)) {
3401  if (!outerDims.test(en.index()))
3402  continue;
3403  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3404  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3405  assert(!ShapedType::isDynamic(sourceSize) &&
3406  "expected padded dimension to have a static size");
3407  if (getConstantIntValue(sliceSize) != sourceSize) {
3408  return rewriter.notifyMatchFailure(
3409  padOp, "cannot fold since the inner ExtractSliceOp size does not "
3410  "match the size of the outer padding");
3411  }
3412  en.value() = outerSliceOp.getMixedSizes()[en.index()];
3413  }
3414 
3415  // Combine the high paddings of the two tensor::PadOps.
3416  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3417  for (auto en : enumerate(newHighPad)) {
3418  if (innerDims.test(en.index()))
3419  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3420  if (outerDims.test(en.index()))
3421  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3422  }
3423 
3424  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3425  // the two paddings in one step.
3426  auto newSliceOp = rewriter.create<ExtractSliceOp>(
3427  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3428  innerSliceOp.getMixedStrides());
3429  auto newPadOp = rewriter.create<PadOp>(
3430  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3431  padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3432  getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3433  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3434  newPadOp.getRegion().begin());
3435  rewriter.replaceOp(padOp, newPadOp.getResult());
3436  return success();
3437  }
3438 };
3439 
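// Folds constant low/high padding operands of tensor.pad into the static
// attributes, producing a more static result type. A hedged sketch with
// hypothetical values: with %c2 = arith.constant 2 : index,
//   %p = tensor.pad %t low[%c2, 0] high[1, 0] { ... } : tensor<4x8xf32> to tensor<?x8xf32>
// becomes a pad with static low = [2, 0] yielding tensor<7x8xf32>, followed by
// a tensor.cast back to tensor<?x8xf32>.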
3440 struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3441  using OpRewritePattern<PadOp>::OpRewritePattern;
3442 
3443  LogicalResult matchAndRewrite(PadOp padTensorOp,
3444  PatternRewriter &rewriter) const override {
3445  Value input = padTensorOp.getSource();
3446  if (!llvm::isa<RankedTensorType>(input.getType()))
3447  return failure();
3448  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3449  auto inputRank = inputDims.size();
3450 
3451  auto oldResultType =
3452  dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3453  if (!oldResultType)
3454  return failure();
3455 
3456  auto outputDims = oldResultType.getShape();
3457 
3458  // Extract the static info from the high and low operands.
3459  SmallVector<int64_t> constOperandsLow;
3460  SmallVector<Value> newLows;
3461  for (auto operand : padTensorOp.getLow()) {
3462  APSInt intOp;
3463  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3464  constOperandsLow.push_back(ShapedType::kDynamic);
3465  newLows.push_back(operand);
3466  continue;
3467  }
3468  constOperandsLow.push_back(intOp.getExtValue());
3469  }
3470  SmallVector<int64_t> constOperandsHigh;
3471  SmallVector<Value> newHighs;
3472  for (auto operand : padTensorOp.getHigh()) {
3473  APSInt intOp;
3474  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3475  constOperandsHigh.push_back(ShapedType::kDynamic);
3476  newHighs.push_back(operand);
3477  continue;
3478  }
3479  constOperandsHigh.push_back(intOp.getExtValue());
3480  }
3481 
3482  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3483  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3484 
3485  // Verify the op is well-formed.
3486  if (inputDims.size() != outputDims.size() ||
3487  inputDims.size() != constLow.size() ||
3488  inputDims.size() != constHigh.size())
3489  return failure();
3490 
3491  auto lowCount = 0;
3492  auto highCount = 0;
3493  for (size_t i = 0; i < inputRank; i++) {
3494  if (constLow[i] == ShapedType::kDynamic)
3495  constLow[i] = constOperandsLow[lowCount++];
3496  if (constHigh[i] == ShapedType::kDynamic)
3497  constHigh[i] = constOperandsHigh[highCount++];
3498  }
3499 
3500  auto staticLow = ArrayRef<int64_t>(constLow);
3501  auto staticHigh = ArrayRef<int64_t>(constHigh);
3502 
3503  // Calculate the output sizes with the static information.
3504  SmallVector<int64_t> newOutDims;
3505  for (size_t i = 0; i < inputRank; i++) {
3506  if (outputDims[i] == ShapedType::kDynamic) {
3507  newOutDims.push_back(
3508  (staticLow[i] == ShapedType::kDynamic ||
3509  staticHigh[i] == ShapedType::kDynamic ||
3510  inputDims[i] == ShapedType::kDynamic
3511  ? ShapedType::kDynamic
3512  : inputDims[i] + staticLow[i] + staticHigh[i]));
3513  } else {
3514  newOutDims.push_back(outputDims[i]);
3515  }
3516  }
3517 
3518  if (SmallVector<int64_t>(outputDims) == newOutDims ||
3519  llvm::all_of(newOutDims,
3520  [&](int64_t x) { return x == ShapedType::kDynamic; }))
3521  return failure();
3522 
3523  // Rewrite the op using the new static type.
3524  auto newResultType = RankedTensorType::get(
3525  newOutDims, padTensorOp.getType().getElementType());
3526  auto newOp = rewriter.create<PadOp>(
3527  padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3528  newLows, newHighs, padTensorOp.getNofold(),
3529  getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3530 
3531  IRMapping mapper;
3532  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3533  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3534  newOp);
3535 
3536  return success();
3537  }
3538 };
3539 
3540 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3541 ///
3542 /// Example:
3543 ///
3544 /// ```mlir
3545 /// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3546 /// tensor.yield %val
3547 /// } : tensor<1x2xf32> to tensor<2x5xf32>
3548 /// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3549 /// tensor.yield %val
3550 /// } : tensor<1x5xf32> to tensor<5x7xf32>
3551 /// ```
3552 ///
3553 /// folds into:
3554 ///
3555 /// ```mlir
3556 /// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3557 /// tensor.yield %val
3558 /// } : tensor<1x2xf32> to tensor<5x7xf32>
3559 /// ```
3560 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3561  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3562 
3563  LogicalResult matchAndRewrite(tensor::PadOp padOp,
3564  PatternRewriter &rewriter) const override {
3565  if (padOp.getNofold()) {
3566  return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3567  }
3568 
3569  auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3570  if (!producerPad || producerPad.getNofold()) {
3571  return rewriter.notifyMatchFailure(
3572  padOp, "producer is not a foldable tensor.pad op");
3573  }
3574 
3575  // Fail if the tensor::PadOps padding values do not match.
3576  Value consumerPadValue = padOp.getConstantPaddingValue();
3577  Value producerPadValue = producerPad.getConstantPaddingValue();
3578  if (!consumerPadValue || !producerPadValue ||
3579  consumerPadValue != producerPadValue) {
3580  return rewriter.notifyMatchFailure(
3581  padOp,
3582  "cannot fold PadOps with different or non-constant padding values");
3583  }
3584 
3585  Location loc = padOp.getLoc();
3586  AffineExpr d0, d1;
3587  bindDims(rewriter.getContext(), d0, d1);
3588 
3589  // Combine the low/high paddings of the two tensor::PadOps.
3590  auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3591  ArrayRef<OpFoldResult> producerPaddings) {
3592  SmallVector<OpFoldResult> sumPaddings;
3593  for (auto [consumerIndex, producerIndex] :
3594  llvm::zip_equal(consumerPaddings, producerPaddings)) {
3595  sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3596  rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3597  }
3598  return sumPaddings;
3599  };
3600 
3601  SmallVector<OpFoldResult> newHighPad =
3602  addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3603  SmallVector<OpFoldResult> newLowPad =
3604  addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3605 
3606  auto newPadOp = rewriter.create<tensor::PadOp>(
3607  padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3608  newLowPad, newHighPad, padOp.getNofold(),
3609  getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3610  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3611  newPadOp.getRegion().begin());
3612  rewriter.replaceOp(padOp, newPadOp.getResult());
3613  return success();
3614  }
3615 };
3616 
3617 } // namespace
3618 
3619 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3620  MLIRContext *context) {
3621  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3622  FoldOrthogonalPaddings, FoldStaticPadding,
3623  FoldConsecutiveConstantPadding>(context);
3624 }
3625 
3626 /// Return the padding value of the PadOp if it is constant. In this context,
3627 /// "constant" means an actual constant or "defined outside of the block".
3628 ///
3629 /// Values are considered constant in three cases:
3630 /// - A ConstantLike value.
3631 /// - A basic block argument from a different block.
3632 /// - A value defined outside of the block.
3633 ///
3634 /// If the padding value is not constant, an empty Value is returned.
3635 Value PadOp::getConstantPaddingValue() {
3636  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3637  if (!yieldOp)
3638  return {};
3639  Value padValue = yieldOp.getValue();
3640  // Check if yield value is a constant.
3641  if (matchPattern(padValue, m_Constant()))
3642  return padValue;
3643  // Check if yield value is defined inside the PadOp block.
3644  if (padValue.getParentBlock() == &getRegion().front())
3645  return {};
3646  // Else: Yield value defined outside of the PadOp block.
3647  return padValue;
3648 }
3649 
3650 OpFoldResult PadOp::fold(FoldAdaptor) {
3651  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3652  !getNofold())
3653  return getSource();
3654  return {};
3655 }
3656 
3657 //===----------------------------------------------------------------------===//
3658 // ParallelInsertSliceOp
3659 //===----------------------------------------------------------------------===//
3660 
3661 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3662  ParallelCombiningOpInterface parallelCombiningParent =
3663  getParallelCombiningParent();
3664  for (const auto &it :
3665  llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3666  Operation &nextOp = it.value();
3667  if (&nextOp == getOperation())
3668  return parallelCombiningParent.getParentResult(it.index());
3669  }
3670  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3671 }
3672 
3673 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3674 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3675  Value source, Value dest,
3676  ArrayRef<OpFoldResult> offsets,
3677  ArrayRef<OpFoldResult> sizes,
3678  ArrayRef<OpFoldResult> strides,
3679  ArrayRef<NamedAttribute> attrs) {
3680  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3681  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3682  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3683  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3684  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3685  result.addAttributes(attrs);
3686  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3687  dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3688  b.getDenseI64ArrayAttr(staticSizes),
3689  b.getDenseI64ArrayAttr(staticStrides));
3690 }
3691 
3692 /// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3693 /// packed into a Range vector.
3694 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3695  Value source, Value dest,
3696  ArrayRef<Range> ranges,
3697  ArrayRef<NamedAttribute> attrs) {
3698  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3699  build(b, result, source, dest, offsets, sizes, strides, attrs);
3700 }
3701 
3702 // Build a ParallelInsertSliceOp with dynamic entries.
3703 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3704  Value source, Value dest, ValueRange offsets,
3705  ValueRange sizes, ValueRange strides,
3706  ArrayRef<NamedAttribute> attrs) {
3707  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3708  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3709  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3710  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3711  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3712  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3713  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3714 }
3715 
3716 LogicalResult ParallelInsertSliceOp::verify() {
3717  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3718  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3719  << *(getOperation()->getParentOp());
3720 
3721  RankedTensorType expectedType;
3722  SliceVerificationResult result =
3723  verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3724  getStaticSizes(), getStaticStrides(), &expectedType);
3725  return produceSliceErrorMsg(result, *this, expectedType);
3726 }
3727 
3728 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3729  RewritePatternSet &results, MLIRContext *context) {
3730  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3731  InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3732  InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3733 }
3734 
3735 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3736  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3737 }
3738 
3739 //===----------------------------------------------------------------------===//
3740 // ScatterOp
3741 //===----------------------------------------------------------------------===//
3742 
3743 void ScatterOp::getAsmResultNames(
3744  function_ref<void(Value, StringRef)> setNameFn) {
3745  setNameFn(getResult(), "scatter");
3746 }
3747 
3748 LogicalResult ScatterOp::verify() {
3749  int64_t destRank = getDestType().getRank();
3750  ArrayRef<int64_t> scatterDims = getScatterDims();
3751  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3752  getIndicesType().getShape(), destRank,
3753  "scatter", "dest")))
3754  return failure();
3755 
3756  if (!getUnique())
3757  return emitOpError("requires 'unique' attribute to be set");
3758  // TODO: we could also check statically that there are fewer leading index
3759  // tensor dims than the dest dims. If this is not the case, the unique
3760  // attribute cannot be true.
3761 
3762  // Use the GatherOp::inferResultType on the `dest` type and verify the
3763  // expected type matches the source type.
3764  RankedTensorType expectedSourceType = GatherOp::inferResultType(
3765  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3766  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3767  getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3768  if (getSourceType() != expectedSourceType &&
3769  getSourceType() != expectedRankReducedSourceType) {
3770  return emitOpError("source type "
3771  "mismatch: "
3772  "expected ")
3773  << expectedSourceType << " or its rank-reduced variant "
3774  << expectedRankReducedSourceType << " (got: " << getSourceType()
3775  << ")";
3776  }
3777 
3778  return success();
3779 }
3780 
3781 //===----------------------------------------------------------------------===//
3782 // SplatOp
3783 //===----------------------------------------------------------------------===//
3784 
3785 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3786  Type aggregateType, ValueRange dynamicSizes) {
3787  build(builder, result, aggregateType, element, dynamicSizes);
3788 }
3789 
3790 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3791  ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3792  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3793  build(builder, result, aggregateType, element, dynamicSizes);
3794 }
3795 
3796 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3797  ArrayRef<OpFoldResult> sizes) {
3798  SmallVector<int64_t> staticShape;
3799  SmallVector<Value> dynamicSizes;
3800  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3801  build(builder, result, element, staticShape, dynamicSizes);
3802 }
3803 
3804 void SplatOp::getAsmResultNames(
3805  function_ref<void(Value, StringRef)> setNameFn) {
3806  setNameFn(getResult(), "splat");
3807 }
3808 
3809 LogicalResult SplatOp::verify() {
3810  if (getType().getNumDynamicDims() != getDynamicSizes().size())
3811  return emitOpError("incorrect number of dynamic sizes, has ")
3812  << getDynamicSizes().size() << ", expected "
3813  << getType().getNumDynamicDims();
3814  return success();
3815 }
3816 
3817 LogicalResult
3818 SplatOp::reifyResultShapes(OpBuilder &builder,
3819  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3820  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
3821  unsigned ctr = 0;
3822  for (int64_t i = 0; i < getType().getRank(); ++i) {
3823  if (getType().isDynamicDim(i)) {
3824  reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
3825  } else {
3826  reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
3827  }
3828  }
3829  return success();
3830 }
3831 
3832 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
3833  auto constOperand = adaptor.getInput();
3834  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3835  return {};
3836 
3837  // Do not fold if the splat is not statically shaped
3838  if (!getType().hasStaticShape())
3839  return {};
3840 
3841  // SplatElementsAttr::get treats a single value for the second arg as a
3842  // splat.
3843  return SplatElementsAttr::get(getType(), {constOperand});
3844 }
3845 
3846 //===----------------------------------------------------------------------===//
3847 // PackOp/UnPackOp Common
3848 //===----------------------------------------------------------------------===//
3849 
3850 template <typename OpTy>
3851 static LogicalResult
3852 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3853  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3854  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3855  "applies to only pack or unpack operations");
3856  int64_t destRank = op.getDestRank();
3857  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3858  reifiedReturnShapes[0] =
3859  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3860  return success();
3861 }
3862 
3863 template <typename OpTy>
3864 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3865  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3866  "applies to only pack or unpack operations");
3867  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3868  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3869  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3870  assert(tiles.size() == dimsToTile.size() &&
3871  "tiles must match indices of dimension to block");
3872  // Bind dimension `i` to its tile factor.
3873  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3874  dimAndTileMapping[dimsToTile[i]] = tiles[i];
3875  return dimAndTileMapping;
3876 }
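// For illustration (hypothetical values): with inner_dims_pos = [0, 2] and
// mixed tiles = [8, 4], the mapping above is {0 -> 8, 2 -> 4}.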
3877 
3878 template <typename OpTy>
3880 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3881  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3881  "applies to only pack or unpack operations");
3882  Builder builder(op);
3883  SmallVector<OpFoldResult> mixedInnerTiles;
3884  unsigned dynamicValIndex = 0;
3885  for (int64_t staticTile : op.getStaticInnerTiles()) {
3886  if (!ShapedType::isDynamic(staticTile))
3887  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3888  else
3889  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3890  }
3891  return mixedInnerTiles;
3892 }
3893 
3894 template <typename OpTy>
3896 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3897  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3897  "applies to only pack or unpack operations");
3898  SmallVector<Value> dynamicTiles;
3899  SmallVector<int64_t> staticTiles;
3900  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3901  return staticTiles;
3902 }
3903 
3904 /// Returns true if `dimsPos` is invalid. It is invalid when:
3905 /// a) It contains duplicates.
3906 /// b) At least one dimension is out of bounds (valid positions satisfy 0 <= dimPos < rank).
3907 /// c) The number of elements in `dimsPos` is greater than `rank`.
3908 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3909  size_t rank) {
3910  size_t dimsPosSize = dimsPos.size();
3911  if (dimsPosSize > rank)
3912  return true;
3913  DenseSet<int64_t> uniqued;
3914  for (int64_t dim : dimsPos)
3915  uniqued.insert(dim);
3916  if (dimsPosSize != uniqued.size())
3917  return true;
3918  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3919  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3920  });
3921 }
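// For illustration (hypothetical values): with rank = 3, dimsPos = [1, 1] is
// invalid (duplicate), dimsPos = [0, 3] is invalid (out of bounds), and
// dimsPos = [2, 0] is valid.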
3922 
3923 /// Returns true if each dimension of `sourceShape` is no larger than the
3924 /// corresponding dimension of `limitShape` (dynamic dimensions are compatible).
3925 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
3926  ArrayRef<int64_t> limitShape) {
3927  assert(
3928  sourceShape.size() == limitShape.size() &&
3929  "expected source shape rank, and limit of the shape to have same rank");
3930  return llvm::all_of(
3931  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3932  int64_t sourceExtent = std::get<0>(it);
3933  int64_t limit = std::get<1>(it);
3934  return ShapedType::isDynamic(sourceExtent) ||
3935  ShapedType::isDynamic(limit) || sourceExtent <= limit;
3936  });
3937 }
3938 
3939 template <typename OpTy>
3940 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
3941  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3942  "applies to only pack or unpack operations");
3943  Operation *op = packOrUnPack.getOperation();
3944 
3945  // Return true if we have a zero-value tile.
3946  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
3947  return llvm::any_of(
3948  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
3949  };
3950 
3951  // Verify tiles. Do not allow zero tiles.
3952  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
3953  if (hasZeros(mixedTiles))
3954  return op->emitError("invalid zero tile factor");
3955 
3956  // Verify inner_dims_pos and outer_dims_perm.
3957  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3958  ? packOrUnPack.getSourceType()
3959  : packOrUnPack.getDestType();
3960  size_t unpackedRank = unpackedType.getRank();
3961  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
3962  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
3963  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
3964  return op->emitError("invalid inner_dims_pos vector");
3965  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
3966  return op->emitError("invalid outer_dims_perm vector");
3967  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3968  return op->emitError("outer_dims_perm must be a permutation or empty");
3969 
3970  // Tiling factors must be less than or equal to the input rank for pack (or
3971  // output rank for unpack), and must match the number of `inner_dims_pos`.
3972  if (mixedTiles.size() > unpackedRank) {
3973  return op->emitError("tiling factors must be less than or equal to the "
3974  "input rank for pack or output rank for unpack");
3975  }
3976  if (mixedTiles.size() != innerDimsPos.size()) {
3977  return op->emitError(
3978  "tiling factors must equal the number of dimensions to tile");
3979  }
3980 
3981  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3982  ? packOrUnPack.getDestType()
3983  : packOrUnPack.getSourceType();
3984  size_t packedRank = packedType.getRank();
3985  // Require output rank to match input rank + number of blocking factors.
3986  if (unpackedRank + mixedTiles.size() != packedRank) {
3987  return op->emitError(
3988  "packed rank must equal unpacked rank + tiling factors");
3989  }
3990 
3991  // Verify that the result shape is at least as large as the minimum shape
3992  // expected by the pack operation, and that the output shape represents
3993  // full tiles.
3994  RankedTensorType expectedPackedType = PackOp::inferPackedType(
3995  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3996  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3997  return op->emitError("the shape of output is not large enough to hold the "
3998  "packed data. Expected at least ")
3999  << expectedPackedType << ", got " << packedType;
4000  }
4001  if (!llvm::all_of(
4002  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4003  mixedTiles),
4004  [](std::tuple<int64_t, OpFoldResult> it) {
4005  int64_t shape = std::get<0>(it);
4006  if (Attribute attr =
4007  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4008  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4009  int64_t staticTileSize = intAttr.getValue().getSExtValue();
4010  return shape == staticTileSize;
4011  }
4012  return ShapedType::isDynamic(shape);
4013  })) {
4014  return op->emitError("mismatch in inner tile sizes specified and shaped of "
4015  "tiled dimension in the packed type");
4016  }
4017  return success();
4018 }
4019 
4020 namespace {
4021 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4022 /// various permutations to the op.
4023 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4024 // these. These may or may not become true foldings / canonicalizations
4025 // depending on how aggressive we want to be in automatically folding
4026 // transposes.
4027 struct PackOrUnPackTransposeResult {
4028  SmallVector<int64_t> innerDimsPos;
4029  SmallVector<OpFoldResult> innerTiles;
4030  SmallVector<int64_t> outerDimsPerm;
4031 };
4032 } // namespace
4033 
4034 template <typename OpTy>
4035 static PackOrUnPackTransposeResult
4036 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4037  ArrayRef<int64_t> innerPermutation,
4038  ArrayRef<int64_t> outerPermutation) {
4039  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4040  "applies to only pack or unpack operations");
4041  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4042  "some permutation must be non-empty");
4043  PackOrUnPackTransposeResult metadata;
4044  metadata.innerDimsPos =
4045  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4046  metadata.innerTiles =
4047  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4048  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4049  ? packOrUnPackOp.getSourceRank()
4050  : packOrUnPackOp.getDestRank();
4051  metadata.outerDimsPerm =
4052  packOrUnPackOp.getOuterDimsPerm().empty()
4053  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4054  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4055  if (!innerPermutation.empty()) {
4056  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4057  isPermutationVector(innerPermutation) &&
4058  "invalid inner permutation");
4059  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4060  applyPermutationToVector(metadata.innerTiles, innerPermutation);
4061  }
4062  if (!outerPermutation.empty()) {
4063  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4064  isPermutationVector(outerPermutation) &&
4065  "invalid outer permutation");
4066  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4067  }
4068  return metadata;
4069 }
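// For illustration (hypothetical values): given inner_dims_pos = [0, 1],
// inner tiles = [8, 2] and innerPermutation = [1, 0], the metadata above
// becomes inner_dims_pos = [1, 0] and inner tiles = [2, 8]; an empty
// outer_dims_perm is first expanded to the identity permutation before
// outerPermutation is applied.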
4070 
4071 //===----------------------------------------------------------------------===//
4072 // PackOp
4073 //===----------------------------------------------------------------------===//
4074 
4075 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4076  setNameFn(getResult(), "pack");
4077 }
4078 
4079 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4080  Value dest, ArrayRef<int64_t> innerDimsPos,
4081  ArrayRef<OpFoldResult> innerTiles,
4082  std::optional<Value> paddingValue,
4083  ArrayRef<int64_t> outerDimsPerm) {
4084  assert(innerDimsPos.size() == innerTiles.size() &&
4085  "number of tile sizes specified must match the specified number of "
4086  "original dimensions to be tiled");
4087  SmallVector<int64_t> staticTileSizes;
4088  SmallVector<Value> dynamicTileSizes;
4089  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4090  build(builder, state, dest.getType(), source, dest,
4091  paddingValue ? *paddingValue : nullptr,
4092  outerDimsPerm.empty() ? nullptr
4093  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4094  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4095  builder.getDenseI64ArrayAttr(staticTileSizes));
4096 }
4097 
4098 LogicalResult
4099 PackOp::reifyResultShapes(OpBuilder &builder,
4100  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4101  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4102 }
4103 
4104 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4105  return getDimAndTileMappingImpl(*this);
4106 }
4107 
4108 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4109  return getMixedTilesImpl(*this);
4110 }
4111 
4112 SmallVector<int64_t> PackOp::getStaticTiles() {
4113  return getStaticTilesImpl(*this);
4114 }
4115 
4116 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4117  ShapedType inputType = getSourceType();
4118  int64_t inputRank = inputType.getRank();
4119  return getDestType().getShape().take_front(inputRank);
4120 }
4121 
4122 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4123  auto innerDimsPos = getInnerDimsPos();
4124  auto packedShape = getDestType().getShape();
4125  SmallVector<int64_t> res;
4126 
4127  for (auto index : innerDimsPos)
4128  res.push_back(packedShape[index]);
4129 
4130  return res;
4131 }
4132 
4133 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4134  ArrayRef<int64_t> innerDimsPos,
4135  ArrayRef<int64_t> outputShape,
4136  ArrayRef<int64_t> outerDimsPerm,
4137  ArrayRef<OpFoldResult> innerTiles) {
4138  SmallVector<int64_t> outputTileSizes(
4139  outputShape.take_front(inputShape.size()));
4140  if (!outerDimsPerm.empty()) {
4141  assert(outerDimsPerm.size() == outputTileSizes.size() &&
4142  "expected output and outer_dims_perm to have same size");
4143  applyPermutationToVector(outputTileSizes,
4144  invertPermutationVector(outerDimsPerm));
4145  }
4146  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4147  if (ShapedType::isDynamic(inputShape[pos]))
4148  continue;
4149  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4150 
4151  if (!constantTile) {
4152  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4153  (inputShape[pos] % outputTileSizes[pos] != 0))
4154  return true;
4155  } else if (inputShape[pos] % (*constantTile) != 0) {
4156  return true;
4157  }
4158  }
4159  return false;
4160 }
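// For illustration (hypothetical sizes): inputShape = [13, 16] with
// innerDimsPos = [0] and a constant inner tile of 8 requires a padding value
// because 13 % 8 != 0, whereas a 16-sized dimension tiled by 8 would not.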
4161 
4162 LogicalResult PackOp::verify() {
4163  if (failed(commonVerifierPackAndUnPackOp(*this)))
4164  return failure();
4165 
4166  // Verify padding value, and bail out if the tile does not divide the
4167  // dimension fully. In the case of dynamic tile factors or dimensions, having
4168  // a partial tile is undefined behavior.
4169  auto paddingValue = getPaddingValue();
4170  if (paddingValue &&
4171  paddingValue.getType() != getSourceType().getElementType()) {
4172  return emitOpError("expected padding_value has ")
4173  << getSourceType().getElementType()
4174  << " but got: " << paddingValue.getType();
4175  }
4176 
4177  if (!paddingValue &&
4178  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4179  getDestType().getShape(), getOuterDimsPerm(),
4180  getMixedTiles())) {
4181  return emitOpError(
4182  "invalid tile factor or output size provided. Only full tiles are "
4183  "supported when padding_value is not set");
4184  }
4185  return success();
4186 }
4187 
4188 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
4189 /// Value's to kDynamic, even if they are arith.constant values.
4190 static SmallVector<int64_t>
4191 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4192  SmallVector<int64_t> result;
4193  for (auto o : ofrs) {
4194  // Have to do this first, as getConstantIntValue special-cases constants.
4195  if (llvm::dyn_cast_if_present<Value>(o))
4196  result.push_back(ShapedType::kDynamic);
4197  else
4198  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4199  }
4200  return result;
4201 }
4202 
4203 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4204 /// the packed type. Having a shared helper helps implement these two methods in
4205 /// a way that ensures that they agree on which dimensions are dynamic.
4206 static SmallVector<int64_t> getPackOpResultTypeShape(
4207  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4208  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4209  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4210  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4211  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4212  continue;
4213  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4214  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4215  continue;
4216  }
4217  resultShape[tiledDim.value()] = divideCeilSigned(
4218  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4219  }
4220 
4221  // Swap tile loops if outer_dims_perm is available.
4222  if (!outerDimsPerm.empty())
4223  applyPermutationToVector(resultShape, outerDimsPerm);
4224 
4225  // Append the inner tile dimensions.
4226  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4227  return resultShape;
4228 }
4229 
4230 SmallVector<OpFoldResult> PackOp::getResultShape(
4231  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4232  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4233  ArrayRef<int64_t> outerDimsPerm) {
4234  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4235 
4236  AffineExpr s0, s1;
4237  bindSymbols(builder.getContext(), s0, s1);
4238  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4239  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4240  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4241  builder, loc, ceilDivExpr,
4242  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4243  }
4244  if (!outerDimsPerm.empty())
4245  applyPermutationToVector(resultDims, outerDimsPerm);
4246  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4247 
4248  SmallVector<int64_t> resultTypeShape =
4249  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4250  asShapeWithAnyValueAsDynamic(innerTileSizes),
4251  innerDimsPos, outerDimsPerm);
4252 
4253  // Fix-up `resultDims` to ensure that they are Value's if and only if the
4254  // result type shape says it's a dynamic dim. This is needed as callers may
4255  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4256  // dynamic dims returned by that.
4257  for (unsigned i = 0; i < resultDims.size(); ++i) {
4258  if (!ShapedType::isDynamic(resultTypeShape[i]))
4259  continue;
4260  resultDims[i] =
4261  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4262  }
4263 
4264  return resultDims;
4265 }
4266 
4267 /// Get the expected packed type based on source type, tile factors, position of
4268 /// the inner tiles and permutation of the outer tiled loop.
4269 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4270  ArrayRef<int64_t> innerTileSizes,
4271  ArrayRef<int64_t> innerDimsPos,
4272  ArrayRef<int64_t> outerDimsPerm) {
4273  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4274  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4275  return RankedTensorType::get(resultShape, sourceType.getElementType());
4276 }
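// A worked sketch (hypothetical shapes): sourceType = tensor<17x8xf32>,
// innerTileSizes = [8, 2], innerDimsPos = [0, 1] and an empty outerDimsPerm
// infer tensor<3x4x8x2xf32>: ceil(17/8) = 3, ceil(8/2) = 4, then the inner
// tile sizes are appended.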
4277 
4278 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4279  ArrayRef<OpFoldResult> innerTileSizes,
4280  ArrayRef<int64_t> innerDimsPos,
4281  ArrayRef<int64_t> outerDimsPerm) {
4282  AffineExpr dim0, dim1;
4283  bindDims(b.getContext(), dim0, dim1);
4284  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4285  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4286  {v1, v2});
4287  };
4288 
4289  SmallVector<OpFoldResult> mixedSizes;
4290  for (auto [index, value] : llvm::enumerate(
4291  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4292  if (ShapedType::isDynamic(value))
4293  mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
4294  else
4295  mixedSizes.push_back(b.getIndexAttr(value));
4296  }
4297  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4298  int64_t dimPos = std::get<0>(it);
4299  OpFoldResult tileSize = std::get<1>(it);
4300  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4301  }
4302  if (!outerDimsPerm.empty())
4303  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4304 
4305  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4306  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4307  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4308 }
4309 
4310 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4311  ArrayRef<int64_t> innerPermutation,
4312  ArrayRef<int64_t> outerPermutation) {
4313  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4314  *this, innerPermutation, outerPermutation);
4315  Value transposedDest =
4316  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4317  metadata.innerDimsPos, metadata.outerDimsPerm);
4318  return b.create<PackOp>(loc, getSource(), transposedDest,
4319  metadata.innerDimsPos, metadata.innerTiles,
4320  getPaddingValue(), metadata.outerDimsPerm);
4321 }
4322 
4323 /// Returns true if the tiles and the tiled dims are constant.
4324 template <typename OpTy>
4326  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4327  "applies to only pack or unpack operations");
4328  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4329  ? op.getDestType()
4330  : op.getSourceType();
4331  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4332  for (auto [dimDest, tile] : llvm::zip(
4333  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4334  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4335  if (!constTileSize || ShapedType::isDynamic(dimDest))
4336  return false;
4337  }
4338  return true;
4339 }
4340 
4341 Speculation::Speculatability PackOp::getSpeculatability() {
4342  if (getPaddingValue())
4343  return Speculation::Speculatable;
4344 
4345  // The verifier already rejects operations where we can statically prove that
4346  // the tile sizes do not divide the dimension evenly; thus, only check that
4347  // the tiles and the tiled inner dimensions are constant.
4348  if (!areTilesAndTiledDimsAllConstant(*this))
4350 
4352 }
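// Speculatability illustration (IR sketch with assumed values, not from this
// file): a pack without a padding value whose inner tile is a runtime SSA
// value, e.g.
//   %p = tensor.pack %src inner_dims_pos = [0] inner_tiles = [%t]
//        into %dest : tensor<?xf32> -> tensor<?x?xf32>
// is NotSpeculatable because %t might not divide the tiled dimension, whereas
// the same op with inner_tiles = [16] and fully static shapes is Speculatable.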
4353 
4354 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4355 // dimensions for pack and unpack.
4356 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4357  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4358  return false;
4359  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4360  return true;
4361  // Outer dims permutation is optional.
 4362  // To compare an unbalanced pack-unpack pair, treat a missing permutation
 4363  // as the identity permutation.
4364  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4365  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4366 }
4367 
4368 // Return true if pack and unpack have the same tiles.
4369 // Same SSA values or same integer constants.
4370 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4371  auto packTiles = packOp.getMixedTiles();
4372  auto unPackTiles = unPackOp.getMixedTiles();
4373  if (packTiles.size() != unPackTiles.size())
4374  return false;
4375  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4376  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4377  return false;
4378  }
4379  return true;
4380 }
4381 
4382 /// Returns true if the pack op does not need a padding value.
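/// For illustration (shapes assumed): a `tensor<128x256xf32>` source packed
/// with `inner_tiles = [32, 16]` needs no padding because each tile size
/// evenly divides its dimension, so an explicit padding value can be dropped;
/// a `tensor<127x256xf32>` source would still require one.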
4383 static bool paddingIsNotNeeded(PackOp op) {
4384  auto srcType = op.getSourceType();
4385  if (llvm::any_of(op.getInnerDimsPos(),
4386  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4387  return false;
4388  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4389  return false;
4390  return !PackOp::requirePaddingValue(
4391  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4392  op.getOuterDimsPerm(), op.getMixedTiles());
4393 }
4394 
4395 /// Returns true if the `srcShape` or `destShape` is different from the one in
4396 /// `packOp` and populates each with the inferred static shape.
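/// For illustration (types assumed): with `inner_dims_pos = [1]`, a source of
/// type `tensor<16x?xf32>` and a destination of type `tensor<?x8x32xf32>`,
/// the untiled dimension 0 can be made static on both sides, yielding
/// `srcShape = [16, ?]` and `destShape = [16, 8, 32]`.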
4397 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4398  SmallVectorImpl<int64_t> &destShape) {
4399  bool changeNeeded = false;
4400  srcShape.assign(packOp.getSourceType().getShape().begin(),
4401  packOp.getSourceType().getShape().end());
4402  destShape.assign(packOp.getDestType().getShape().begin(),
4403  packOp.getDestType().getShape().end());
4404  llvm::SmallSetVector<int64_t, 4> innerDims;
4405  innerDims.insert(packOp.getInnerDimsPos().begin(),
4406  packOp.getInnerDimsPos().end());
4407  SmallVector<int64_t> inverseOuterDimsPerm;
4408  if (!packOp.getOuterDimsPerm().empty())
4409  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4410  int srcRank = packOp.getSourceRank();
4411  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4412  if (innerDims.contains(i))
4413  continue;
4414  int64_t srcPos = i;
4415  int64_t destPos = i;
4416  if (!inverseOuterDimsPerm.empty())
4417  destPos = inverseOuterDimsPerm[srcPos];
4418  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4419  ShapedType::isDynamic(destShape[destPos])) {
4420  continue;
4421  }
4422  int64_t size = srcShape[srcPos];
4423  if (ShapedType::isDynamic(size))
4424  size = destShape[destPos];
4425  srcShape[srcPos] = size;
4426  destShape[destPos] = size;
4427  changeNeeded = true;
4428  }
4429  return changeNeeded;
4430 }
4431 
4432 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 4433  // Fold a pack(unpack(x)) to x.
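  // For illustration (IR sketch with assumed shapes): given matching tiles and
  // dims and no padding,
  //   %u = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16]
  //        into %t0 : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
  //   %p = tensor.pack %u inner_dims_pos = [0, 1] inner_tiles = [32, 16]
  //        into %t1 : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
  // canonicalizes so that %p is replaced by %0.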
4434  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4435  if (unPackOp.getSourceType() != packOp.getDestType())
4436  return failure();
4437  if (packOp.getPaddingValue() ||
4438  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4439  !haveSameTiles(packOp, unPackOp))
4440  return failure();
4441  rewriter.replaceOp(packOp, unPackOp.getSource());
4442  return success();
4443  }
4444 
4445  // Fold optional PaddingValue operand away if padding is not needed.
4446  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4447  rewriter.startOpModification(packOp);
4448  packOp.getPaddingValueMutable().clear();
4449  rewriter.finalizeOpModification(packOp);
4450  return success();
4451  }
4452 
 4453  // Insert tensor.cast ops if static shape inference is available.
4454  SmallVector<int64_t> srcShape, destShape;
4455  if (inferStaticShape(packOp, srcShape, destShape)) {
4456  Location loc = packOp.getLoc();
4457  Value source = packOp.getSource();
4458  if (srcShape != packOp.getSourceType().getShape()) {
4459  auto newSrcType = packOp.getSourceType().clone(srcShape);
4460  source =
4461  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4462  }
4463  Value dest = packOp.getDest();
4464  RankedTensorType originalResultType = packOp.getDestType();
4465  bool needUpdateDestType = (destShape != originalResultType.getShape());
4466  if (needUpdateDestType) {
4467  auto newDestType = packOp.getDestType().clone(destShape);
4468  dest =
4469  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4470  }
4471  rewriter.modifyOpInPlace(packOp, [&] {
4472  packOp.getSourceMutable().assign(source);
4473  packOp.getDestMutable().assign(dest);
4474  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4475  });
4476  // Insert a cast if needed
4477  if (needUpdateDestType) {
4478  rewriter.setInsertionPointAfter(packOp);
4479  auto castOp =
4480  rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4481  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4482  }
4483  return success();
4484  }
4485 
4486  return failure();
4487 }
4488 
4489 template <typename PackOrUnpackOp>
4490 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4491  RankedTensorType packedTensorType) {
4492  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4493  std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4494  "Function meant for pack/unpack");
4495  // This is a pad if packing only adds ones and we don't transpose dimensions.
4496 
4497  // Check that we are not transposing any dimensions.
4498  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4499  int64_t numPackedDims = innerDimsPos.size();
4500  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4501  if (orderedDims != innerDimsPos) {
4502  // Dimensions don't happen in order.
4503  return false;
4504  }
4505 
4506  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4507  int64_t packedRank = packedTensorType.getRank();
 4508  // At this point we know that we are taking numPackedDims outer
 4509  // dimensions and pushing them all the way in as the innermost dimensions.
 4510  // What's left on the outermost dimensions is, in this order:
 4511  // - the factors of the packed dimensions, then
 4512  // - the untouched dimensions.
 4513  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
 4514  // only if all the dimensions that bubble outward are ones.
 4515  // Therefore check that all the dimensions except the numPackedDims
 4516  // innermost ones are ones.
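  // For illustration (shapes assumed): packing a tensor<128x256xf32> with
  // inner_dims_pos = [0, 1] and inner_tiles = [128, 256] yields a
  // tensor<1x1x128x256xf32>; all outer dimensions are 1, so the op only adds
  // unit dimensions and behaves like a pad.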
4517  return llvm::all_of(
4518  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4519  [&packedShape](int64_t i) { return packedShape[i] == 1; });
4520 }
4521 
4522 bool PackOp::isLikePad() {
4523  auto packedTensorType =
4524  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4525  return isLikePadUnPad(*this, packedTensorType);
4526 }
4527 
4528 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4529  std::optional<Attribute> paddingValue;
4530  if (auto pad = adaptor.getPaddingValue())
4531  paddingValue = pad;
4532  if (OpFoldResult reshapedSource = reshapeConstantSource(
4533  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4534  getDestType(), paddingValue))
4535  return reshapedSource;
4536  return {};
4537 }
4538 
4539 //===----------------------------------------------------------------------===//
4540 // UnPackOp
4541 //===----------------------------------------------------------------------===//
4542 
4543 void UnPackOp::getAsmResultNames(
4544  function_ref<void(Value, StringRef)> setNameFn) {
4545  setNameFn(getResult(), "unpack");
4546 }
4547 
4548 LogicalResult
 4549 UnPackOp::reifyResultShapes(OpBuilder &builder,
 4550  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4551  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4552 }
4553 
4554 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4555  return getDimAndTileMappingImpl(*this);
4556 }
4557 
4558 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4559  return getMixedTilesImpl(*this);
4560 }
4561 
4562 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4563  return getStaticTilesImpl(*this);
4564 }
4565 
4566 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
4567  ShapedType destType = getDestType();
4568  int64_t destRank = destType.getRank();
4569  return getSourceType().getShape().take_front(destRank);
4570 }
4571 
4572 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
4573  auto innerDimsPos = getInnerDimsPos();
4574  auto packedShape = getSourceType().getShape();
 4575  SmallVector<int64_t> res;
 4576 
4577  for (auto index : innerDimsPos)
4578  res.push_back(packedShape[index]);
4579 
4580  return res;
4581 }
4582 
4583 LogicalResult UnPackOp::verify() {
4584  return commonVerifierPackAndUnPackOp(*this);
4585 }
4586 
4587 Speculation::Speculatability UnPackOp::getSpeculatability() {
4588  // See PackOp::getSpeculatability.
4589  if (!areTilesAndTiledDimsAllConstant(*this))
 4590  return Speculation::NotSpeculatable;
 4591 
 4592  return Speculation::Speculatable;
 4593 }
4594 
4595 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4596  Value dest, ArrayRef<int64_t> innerDimsPos,
4597  ArrayRef<OpFoldResult> innerTiles,
4598  ArrayRef<int64_t> outerDimsPerm) {
4599  assert(innerDimsPos.size() == innerTiles.size() &&
4600  "number of tile sizes specified must match the specified number of "
4601  "original dimensions to be tiled");
4602  SmallVector<int64_t> staticTileSizes;
4603  SmallVector<Value> dynamicTileSizes;
4604  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4605  build(builder, state, dest.getType(), source, dest,
4606  outerDimsPerm.empty() ? nullptr
4607  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4608  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4609  builder.getDenseI64ArrayAttr(staticTileSizes));
4610 }
4611 
4612 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4613  Value source,
4614  ArrayRef<OpFoldResult> innerTileSizes,
4615  ArrayRef<int64_t> innerDimsPos,
4616  ArrayRef<int64_t> outerDimsPerm) {
4617  AffineExpr sym0, sym1;
4618  bindSymbols(b.getContext(), sym0, sym1);
4619  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4620  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4621  };
4622 
4623  SmallVector<OpFoldResult> mixedSizes;
4624  auto srcType = llvm::cast<RankedTensorType>(source.getType());
4625  for (auto i :
4626  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4627  if (srcType.isDynamicDim(i))
4628  mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
4629  else
4630  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4631  }
4632  if (!outerDimsPerm.empty()) {
4633  applyPermutationToVector<OpFoldResult>(
4634  mixedSizes, invertPermutationVector(outerDimsPerm));
4635  }
4636 
4637  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4638  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4639 
4640  auto elemType = srcType.getElementType();
4641  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4642 }
4643 
4644 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4645  Value transposedSource,
4646  ArrayRef<int64_t> innerPermutation,
4647  ArrayRef<int64_t> outerPermutation) {
4648  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4649  *this, innerPermutation, outerPermutation);
4650  return b.create<UnPackOp>(loc, transposedSource, getDest(),
4651  metadata.innerDimsPos, metadata.innerTiles,
4652  metadata.outerDimsPerm);
4653 }
4654 
4655 /// Returns true if the `srcShape` or `destShape` is different from the one in
4656 /// `op` and populates each with the inferred static shape.
4657 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4658  SmallVectorImpl<int64_t> &destShape) {
4659  bool changeNeeded = false;
4660  srcShape.assign(op.getSourceType().getShape().begin(),
4661  op.getSourceType().getShape().end());
4662  destShape.assign(op.getDestType().getShape().begin(),
4663  op.getDestType().getShape().end());
4664  llvm::SmallSetVector<int64_t, 4> innerDims;
4665  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4666  SmallVector<int64_t> inverseOuterDimsPerm;
4667  if (!op.getOuterDimsPerm().empty())
4668  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4669  int destRank = op.getDestRank();
4670  for (auto i : llvm::seq<int64_t>(0, destRank)) {
4671  if (innerDims.contains(i))
4672  continue;
4673  int64_t srcPos = i;
4674  int64_t destPos = i;
4675  if (!inverseOuterDimsPerm.empty())
4676  srcPos = inverseOuterDimsPerm[destPos];
4677  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4678  ShapedType::isDynamic(destShape[destPos])) {
4679  continue;
4680  }
4681  int64_t size = srcShape[srcPos];
4682  if (ShapedType::isDynamic(size))
4683  size = destShape[destPos];
4684  srcShape[srcPos] = size;
4685  destShape[destPos] = size;
4686  changeNeeded = true;
4687  }
4688  return changeNeeded;
4689 }
4690 
4691 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4692  PatternRewriter &rewriter) {
4693  /// unpack(pack(x)) -> x
4694  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4695  if (packOp.getSourceType() != unPackOp.getDestType())
4696  return failure();
4697  if (packOp.getPaddingValue() ||
4698  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4699  !haveSameTiles(packOp, unPackOp))
4700  return failure();
4701  rewriter.replaceOp(unPackOp, packOp.getSource());
4702  return success();
4703  }
4704  /// unpack(destinationStyleOp(x)) -> unpack(x)
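  // For illustration (IR sketch, names assumed): if the destination comes from
  // a DPS op such as linalg.fill,
  //   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<128x256xf32>)
  //           -> tensor<128x256xf32>
  //   %u = tensor.unpack %src ... into %fill ...
  // the unpack is rewritten to use %empty directly as its destination, since
  // it fully overwrites that destination anyway.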
4705  if (auto dstStyleOp =
4706  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4707  auto destValue = cast<OpResult>(unPackOp.getDest());
4708  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4709  rewriter.modifyOpInPlace(unPackOp,
4710  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4711  return success();
4712  }
4713 
 4714  // Insert tensor.cast ops if static shape inference is available.
4715  SmallVector<int64_t> srcShape, destShape;
4716  if (inferStaticShape(unPackOp, srcShape, destShape)) {
4717  Location loc = unPackOp.getLoc();
4718  Value source = unPackOp.getSource();
4719  if (srcShape != unPackOp.getSourceType().getShape()) {
4720  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4721  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
4722  unPackOp.getSource());
4723  }
4724  Value dest = unPackOp.getDest();
4725  if (destShape != unPackOp.getDestType().getShape()) {
4726  auto newDestType = unPackOp.getDestType().clone(destShape);
4727  dest =
4728  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4729  }
4730  Value newOp = rewriter.create<tensor::UnPackOp>(
4731  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4732  unPackOp.getOuterDimsPerm());
4733  rewriter.replaceOpWithNewOp<tensor::CastOp>(
4734  unPackOp, unPackOp.getResult().getType(), newOp);
4735  return success();
4736  }
4737 
4738  return failure();
4739 }
4740 
4741 bool UnPackOp::isLikeUnPad() {
4742  RankedTensorType packedTensorType = getSourceType();
4743  return isLikePadUnPad(*this, packedTensorType);
4744 }
4745 
4746 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
4747  if (OpFoldResult reshapedSource = reshapeConstantSource(
4748  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4749  getResult().getType()))
4750  return reshapedSource;
4751  return {};
4752 }
4753 
4754 //===----------------------------------------------------------------------===//
4755 // Common Canonicalizers and Folders.
4756 //===----------------------------------------------------------------------===//
4757 bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4758  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4759  // 2. Exclude DPS ops that are also LoopLike from this interface as they
4760  // might need special handling of attached regions.
4761  if (isa<InsertSliceOp>(op.getOperation()) ||
4762  isa<LoopLikeOpInterface>(op.getOperation()))
4763  return false;
4764 
 4765  // If no operand comes from a tensor::CastOp that can be folded, then fail.
4766  bool hasTensorCastOperand =
4767  llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
4768  if (llvm::isa<BlockArgument>(opOperand.get()))
4769  return false;
4770  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4771  return castOp && canFoldIntoConsumerOp(castOp);
4772  });
4773 
4774  return hasTensorCastOperand;
4775 }
4776 
4777 static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
4778  SmallVector<Type> &newResTy) {
4779  SmallVector<Value> newOperands;
4780  newOperands.reserve(op->getNumOperands());
4781 
4782  // Assumes that the result has dpsInits followed by nonDpsInits.
4783  int64_t dpsInitIdx = 0;
4784  for (OpOperand &opOperand : op->getOpOperands()) {
4785  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4786  bool fold = canFoldIntoConsumerOp(tensorCastOp);
4787  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
4788  if (op.isDpsInit(&opOperand) &&
4789  !llvm::isa<MemRefType>(newOperands.back().getType()))
4790  newResTy[dpsInitIdx++] = newOperands.back().getType();
4791  }
4792  return newOperands;
4793 }
4794 
4795 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
 4796 /// `tensor.cast` has a source that is more static than the consuming op.
4797 ///
4798 /// Example:
4799 /// ```mlir
4800 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4801 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
4802 /// ```
4803 ///
4804 /// folds into:
4805 ///
4806 /// ```mlir
4807 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
4808 /// ```
4809 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
 4810  using OpRewritePattern<PackOp>::OpRewritePattern;
 4811 
4812  LogicalResult matchAndRewrite(PackOp op,
4813  PatternRewriter &rewriter) const override {
4814  if (!foldTensorCastPrecondition(op))
4815  return failure();
4816 
4817  SmallVector<Type> newResultTypes(op->getResultTypes());
4818  SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4819 
4820  // Get the updated mixed-tile-sizes attribute.
4821  SmallVector<OpFoldResult> newMixedTileSizes;
4822  for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
4823  .getShape()
4824  .take_back(op.getMixedTiles().size()),
4825  op.getMixedTiles())) {
4826  int64_t shape = std::get<0>(it);
4827  if (shape == ShapedType::kDynamic) {
4828  newMixedTileSizes.push_back(std::get<1>(it));
4829  continue;
4830  }
4831 
4832  if (Attribute attr =
4833  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4834  // Already a constant
4835  newMixedTileSizes.push_back(std::get<1>(it));
4836  } else {
4837  int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
4838  assert(tileSize == shape && "tile size and dim size don't match!");
4839  (void)tileSize;
4840  newMixedTileSizes.push_back(
4841  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4842  }
4843  }
4844 
4845  // Clone op.
4846  PackOp newOp = rewriter.create<PackOp>(
4847  op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
4848  newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
4849  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
4850 
4851  // Replace op.
4852  Value oldResult = op.getResult();
4853  Value newResult = newOp.getResult();
4854  Value replacement = (newResult.getType() != oldResult.getType())
4855  ? rewriter.create<tensor::CastOp>(
4856  op->getLoc(), oldResult.getType(), newResult)
4857  : newResult;
4858 
4859  rewriter.replaceOp(op, {replacement});
4860 
4861  return success();
4862  }
4863 };
4864 
4865 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 4866 /// the `tensor.cast` has a source that is more static than the consuming op.
4867 ///
4868 /// Example:
4869 /// ```mlir
4870 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4871 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
4872 /// ```
4873 ///
4874 /// folds into:
4875 ///
4876 /// ```mlir
4877 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
4878 /// ```
 4879 /// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
 4880 /// implementations can add the pattern to their canonicalizers.
 4881 struct FoldTensorCastProducerOp
 4882  : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
 4883  using OpInterfaceRewritePattern<
 4884  DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4885 
4886  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4887  PatternRewriter &rewriter) const override {
4888 
 4889  // Reject tensor::PackOp - there's a dedicated pattern for that instead.
4890  if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
4891  return failure();
4892 
4893  SmallVector<Type> newResultTypes(op->getResultTypes());
4894  SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
4895 
4896  // Clone op
4897  auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4898 
4899  SmallVector<Value, 4> replacements;
4900  replacements.reserve(newOp->getNumResults());
4901  for (auto [oldResult, newResult] :
4902  llvm::zip(op->getResults(), newOp->getResults())) {
4903  if (newResult.getType() != oldResult.getType()) {
4904  replacements.push_back(rewriter.create<tensor::CastOp>(
4905  op->getLoc(), oldResult.getType(), newResult));
4906  } else {
4907  replacements.push_back(newResult);
4908  }
4909  }
4910  rewriter.replaceOp(op, replacements);
4911 
4912  return success();
4913  }
4914 };
4915 
4916 //===----------------------------------------------------------------------===//
4917 // TensorDialect
4918 //===----------------------------------------------------------------------===//
4919 
4920 void TensorDialect::getCanonicalizationPatterns(
4921  RewritePatternSet &results) const {
4922  results.add<FoldTensorCastPackOp>(getContext());
 4923  results.add<FoldTensorCastProducerOp>(getContext());
 4924 }
4925 
4926 //===----------------------------------------------------------------------===//
4927 // TableGen'd op method definitions
4928 //===----------------------------------------------------------------------===//
4929 
4930 #define GET_OP_CLASSES
4931 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"