//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (ConstantOp::isBuildableWith(value, type))
    return builder.create<ConstantOp>(loc, value, type);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
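/// For example (values are illustrative):
///   preservesStaticInformation(tensor<8x?xf32>, tensor<8x16xf32>) -> true
///     (the target only adds static information)
///   preservesStaticInformation(tensor<8x16xf32>, tensor<8x?xf32>) -> false
///     (the static size 16 is lost)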
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = source.dyn_cast<RankedTensorType>();
  auto targetType = target.dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same element type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of tensor.cast operations. Such foldable
/// tensor.cast operations are typically inserted as `slice` ops and are
/// canonicalized to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source type has more static information than the result type.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of the cast has at least as much static
  // information as its result.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.source().getType());
}

/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
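/// For example, joining tensor<?x8xf32> with tensor<4x?xf32> yields
/// tensor<4x8xf32>; joining tensor<4x8xf32> with tensor<5x8xf32> fails
/// (conflicting static sizes) and returns a null type.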
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
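/// For example (illustrative):
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
/// ```
/// folds to a single `tensor.cast %0 : tensor<4x?xf32> to tensor<?x8xf32>`.
/// A chain tensor<?xf32> -> tensor<4xf32> -> tensor<?xf32>, in contrast, is
/// not folded: dropping the intermediate cast would drop the runtime check
/// that the extent is 4.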
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with tensor type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = source().getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = source().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on the `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
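/// For example (illustrative):
/// ```mlir
/// %0 = tensor.cast %arg : tensor<4x?xf32> to tensor<?x?xf32>
/// %dim = tensor.dim %0, %idx : tensor<?x?xf32>
/// ```
/// becomes `%dim = tensor.dim %arg, %idx : tensor<4x?xf32>`.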
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract_element");

  return success();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of
  // the elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue<Attribute>();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValues<Attribute>()[indices];
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, ValueRange elements) {
  result.addOperands(elements);
  result.addTypes(resultType);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
  return {};
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (!tensorFromElements)
      return failure();
    auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
    auto rank = tensorType.getRank();
    if (rank == 0) {
      rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
      return success();
    }
    SmallVector<APInt, 3> indices(rank);
    int64_t flatIndex = 0;
    int64_t stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      APInt index;
      if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
        return failure();
      if (i < rank - 1)
        stride *= tensorType.getDimSize(i);
      flatIndex += index.getSExtValue() * stride;
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
      return failure();
    rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
    return success();
  }
};

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// Consider expanding this to a template and handling all tensor cast
// operations.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.indices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results
      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
          context);
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(InsertOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
  Attribute scalar = operands[0];
  Attribute dest = operands[1];
  if (scalar && dest)
    if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());

  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with constant extent operands by
/// folding those extents into the result type. A tensor.cast back to the
/// original type is inserted so that the resulting IR remains well-typed.
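/// For example (illustrative):
/// ```mlir
/// %c4 = arith.constant 4 : index
/// %0 = tensor.generate %c4 { ... } : tensor<?xf32>
/// ```
/// becomes a tensor.generate with result type tensor<4xf32>, followed by a
/// `tensor.cast` from tensor<4xf32> back to tensor<?xf32>.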
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (!ShapedType::isDynamic(dim)) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(ShapedType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

/// Returns the number of elements of the given statically shaped type.
static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

static LogicalResult verify(ReshapeOp op) {
  TensorType operandType = op.source().getType().cast<TensorType>();
  TensorType resultType = op.result().getType().cast<TensorType>();

  if (operandType.getElementType() != resultType.getElementType())
    return op.emitOpError("element types of source and destination tensor "
                          "types should be the same");

  int64_t shapeSize =
      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
        return op.emitOpError("source and destination tensor should have the "
                              "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

static void print(OpAsmPrinter &p, ExpandShapeOp op) {
  ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
}

static void print(OpAsmPrinter &p, CollapseShapeOp op) {
  ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,
                                  ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamicSize))
      size = ShapedType::kDynamicSize;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = computeTensorReshapeCollapsedType(
      src.getType().cast<RankedTensorType>(),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<NamedAttribute> attrs) {
  auto resultType = computeTensorReshapeCollapsedType(
      src.getType().cast<RankedTensorType>(),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  build(b, result, resultType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      computeTensorReshapeCollapsedType(expandedType, maps);
  if (collapsedType != expectedType)
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}

static LogicalResult verify(ExpandShapeOp op) {
  return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
}

static LogicalResult verify(CollapseShapeOp op) {
  return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
}

namespace {
/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
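/// For example, collapsing a splat constant
/// `arith.constant dense<1.0> : tensor<2x2xf32>` into tensor<4xf32> can
/// instead materialize `arith.constant dense<1.0> : tensor<4xf32>` directly.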
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData(), true);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};

/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type.
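/// For example (illustrative), expanding the result of
/// `tensor.from_elements %a, %b, %c, %d : tensor<4xf32>` into
/// tensor<2x2xf32> rebuilds the from_elements op directly with the
/// tensor<2x2xf32> result type.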
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.src().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = reshapeOp.getType().template cast<ShapedType>();

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.elements());
    return success();
  }
};

} // namespace

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CollapseReshapeOps<ExpandShapeOp>,
              CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
              FoldReshapeWithConstant<ExpandShapeOp>,
              FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CollapseReshapeOps<CollapseShapeOp>,
              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
              FoldReshapeWithConstant<CollapseShapeOp>,
              FoldReshapeWithFromElements<CollapseShapeOp>>(context);
}

OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}

//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
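/// For example, staticSizes = [4, ShapedType::kDynamicSize] on a
/// tensor<8x16xf32> source infers the result type tensor<4x?xf32>.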
RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceRankedTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  // An extract_slice op may specify only a leading subset of offset/sizes/
  // strides in which case we complete with offset=0, sizes from the source
  // tensor type and strides=1.
  unsigned rank = sourceRankedTensorType.getRank();
  (void)rank;
  assert(staticSizes.size() == rank &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes,
                               sourceRankedTensorType.getElementType());
}

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceRankedTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}

/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case. This variant additionally projects away
/// unit dimensions until the requested `resultRank` is reached.
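/// For example, sizes = [1, 4, 8] on a tensor<2x4x8xf32> source with
/// resultRank = 2 projects away the leading unit dimension and infers
/// tensor<4x8xf32>.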
RankedTensorType ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides)
          .cast<RankedTensorType>();
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferRankReducedResultType(
    unsigned resultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType =
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                        staticSizes, staticStrides)
            .cast<RankedTensorType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with dynamic entries and custom result type. If the
/// type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          OpTy op, Type expectedType) {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected element type to be ")
           << memrefType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}

/// Verifier for ExtractSliceOp.
static LogicalResult verify(ExtractSliceOp op) {
  // Verify result type against inferred type.
  auto expectedType =
      ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
                                      op.getMixedSizes(), op.getMixedStrides());
  auto result =
      isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
  return produceSliceErrorMsg(result, op, expectedType);
}

/// Infer the canonical type of the result of an extract_slice op. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static RankedTensorType
getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
                            ArrayRef<OpFoldResult> mixedOffsets,
                            ArrayRef<OpFoldResult> mixedSizes,
                            ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      ExtractSliceOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<RankedTensorType>();
  if (resultType.getRank() != resultRank) {
    resultType = ExtractSliceOp::inferResultType(sourceType, mixedOffsets,
                                                 mixedSizes, mixedStrides)
                     .cast<RankedTensorType>();
  }
  return resultType;
}

llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
  llvm::SmallDenseSet<unsigned> droppedDims;
  ArrayRef<int64_t> resultShape = getType().getShape();
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  unsigned shapePos = 0;
  for (const auto &size : enumerate(mixedSizes)) {
    Optional<int64_t> sizeVal = getConstantIntValue(size.value());
    // If the size is not 1, or if the current matched dimension of the result
    // is the same static shape as the size value (which is 1), then the
    // dimension is preserved.
    if (!sizeVal || sizeVal.getValue() != 1 ||
        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
      shapePos++;
      continue;
    }
    droppedDims.insert(size.index());
  }
  return droppedDims;
}

LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims();
  Location loc = getLoc();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.count(size.index()))
      continue;
    if (auto attr = size.value().dyn_cast<Attribute>()) {
      reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
          loc, attr.cast<IntegerAttr>().getInt()));
      continue;
    }
    reifiedReturnShapes[0].push_back(size.value().get<Value>());
  }
  return success();
}

namespace {
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
///                                                  tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///                                                  tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.source().getDefiningOp<tensor::CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the type of the result to use for the canonicalized operation.
    RankedTensorType resultType = getCanonicalSliceResultType(
        sliceOp.getType().getRank(), sliceOp.getSourceType(),
        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
        sliceOp.getMixedStrides());
    Value newSlice = rewriter.create<ExtractSliceOp>(
        sliceOp.getLoc(), resultType, castOp.source(), sliceOp.offsets(),
        sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
        sliceOp.static_sizes(), sliceOp.static_strides());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
                                                newSlice);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSliceResultType(op.getType().getRank(),
                                       op.getSourceType(), mixedOffsets,
                                       mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace ExtractSliceOps.
struct SliceCanonicalizer {
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }
};

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}

/// Returns success if `op` is a noop slice: all offsets are 0, all strides
/// are 1, and the sizes match the shape of `shapedType`.
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  OpBuilder b(op.getContext());
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
      return failure();
  // Rank-reducing noops only need to inspect the leading dimensions:
  // llvm::zip is appropriate.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
      return failure();
  return success();
}

/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
/// slice, we can return the InsertSliceOp's source directly.
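/// For example (illustrative):
/// ```mlir
/// %0 = tensor.insert_slice %src into %dst[0, 1] [2, 3] [1, 1]
/// %1 = tensor.extract_slice %0[0, 1] [2, 3] [1, 1]
/// ```
/// folds %1 to %src, provided the slice parameters and types match exactly.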
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
  auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.source().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.source();

  return {};
}

OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  if (Value slice = foldExtractAfterInsertSlice(*this))
    return slice;
  return OpFoldResult();
}

Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
  auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
  unsigned rank = rankedTensorType.getRank();
  auto shape = rankedTensorType.getShape();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes;
  for (unsigned i = 0, e = rank; i < e; ++i) {
    OpFoldResult dim;
    if (rankedTensorType.isDynamicDim(i))
      dim = b.createOrFold<tensor::DimOp>(
          loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
    else
      dim = b.getIndexAttr(shape[i]);
    sizes.push_back(dim);
  }
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
}

//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          ArrayRef<OpFoldResult> strides,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

/// Verifier for InsertSliceOp.
static LogicalResult verify(InsertSliceOp op) {
  // insert_slice is the inverse of extract_slice; use the same type inference.
  auto expectedType = ExtractSliceOp::inferRankReducedResultType(
      op.getSourceType().getRank(), op.getType(),
      extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));
  auto result =
      isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
  return produceSliceErrorMsg(result, op, expectedType);
}

/// If we have two consecutive InsertSliceOp ops writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
///
/// Example:
///
/// ```mlir
/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
/// ```
///
/// folds into:
///
/// ```mlir
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
/// ```
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
  auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!prevInsertOp ||
      prevInsertOp.source().getType() != insertOp.source().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.destMutable().assign(prevInsertOp.dest());
  return success();
}

OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->source();
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  return OpFoldResult();
}

LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    reifiedReturnShapes[0][dim] =
        builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
  }
  return success();
}

namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
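/// For example (illustrative):
/// ```mlir
/// %c64 = arith.constant 64 : index
/// %r = tensor.insert_slice %0 into %1[0, 0] [%c64, 64] [1, 1] : ...
/// ```
/// is rewritten so that %c64 becomes the static size 64, inserting a
/// tensor.cast of the source when its inferred type changes.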
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertSliceOp> {
public:
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // At least one of offsets/sizes/strides is a new constant.
    // Form the new list of operands and constant attributes from the
    // existing.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.source();
    if (sourceType != insertSliceOp.getSourceType())
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold tensor.cast ops into insert_slice operations. If the source or
/// destination tensor is a tensor.cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast to ensure that the type of the result did
/// not change.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return llvm::None;
      return castOp.source();
    };
    Optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.source());
    Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.dest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    Value replacement = rewriter.create<InsertSliceOp>(
        insertSliceOp.getLoc(),
        (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
        (destCastSource ? *destCastSource : insertSliceOp.dest()),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

    if (replacement.getType() != insertSliceOp.getType()) {
      replacement = rewriter.create<tensor::CastOp>(
          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
    }
    rewriter.replaceOp(insertSliceOp, replacement);
    return success();
  }
};

/// If additional static type information can be deduced from an insert_slice
/// op's size operands, insert an explicit cast of the op's source operand.
/// This enables other canonicalization patterns that are matching for
/// tensor_cast ops such as `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
///     : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
///     : tensor<64x64xf32> into ...
/// ```
1482 struct InsertSliceOpSourceCastInserter final
1483  : public OpRewritePattern<InsertSliceOp> {
1484  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1485 
1486  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1487  PatternRewriter &rewriter) const override {
1488  RankedTensorType srcType = insertSliceOp.getSourceType();
1489  if (srcType.getRank() != insertSliceOp.getType().getRank())
1490  return failure();
1491  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1492  srcType.getShape().end());
1493  for (int64_t i = 0; i < srcType.getRank(); ++i) {
1494  if (Optional<int64_t> constInt =
1495  getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1496  newSrcShape[i] = *constInt;
1497  }
1498 
1499  RankedTensorType newSrcType =
1500  RankedTensorType::get(newSrcShape, srcType.getElementType());
1501  if (srcType == newSrcType ||
1502  !preservesStaticInformation(srcType, newSrcType) ||
1503  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
1504  return failure();
1505 
1506  // newSrcType is:
1507  // 1) Different from srcType.
1508  // 2) "More static" than srcType.
1509  // 3) Cast-compatible with srcType.
1510  // Insert the cast.
1511  Value cast = rewriter.create<tensor::CastOp>(
1512  insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
1513  rewriter.replaceOpWithNewOp<InsertSliceOp>(
1514  insertSliceOp, cast, insertSliceOp.dest(),
1515  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1516  insertSliceOp.getMixedStrides());
1517  return success();
1518  }
1519 };
1520 } // namespace
1521 
1522 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1523  MLIRContext *context) {
1524  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
1525  InsertSliceOpSourceCastInserter>(context);
1526 }
1527 
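/// Editor's sketch (not part of the original source) of the IR this helper
/// produces, assuming a destination of type tensor<4x?xf32> and hypothetical
/// value names:
///
/// ```mlir
/// %c1 = arith.constant 1 : index
/// %d1 = tensor.dim %dest, %c1 : tensor<4x?xf32>
/// %r = tensor.insert_slice %tensor into %dest[0, 0] [4, %d1] [1, 1]
///        : tensor<4x?xf32> into tensor<4x?xf32>
/// ```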
1528 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
1529  Location loc,
1530  Value tensor,
1531  Value dest) {
1532  auto rankedTensorType = dest.getType().cast<RankedTensorType>();
1533  unsigned rank = rankedTensorType.getRank();
1534  auto shape = rankedTensorType.getShape();
1535  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1536  SmallVector<OpFoldResult> sizes;
1537  for (unsigned i = 0, e = rank; i < e; ++i) {
1538  OpFoldResult dim;
1539  if (rankedTensorType.isDynamicDim(i))
1540  dim = b.createOrFold<tensor::DimOp>(
1541  loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
1542  else
1543  dim = b.getIndexAttr(shape[i]);
1544  sizes.push_back(dim);
1545  }
1546  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1547  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
1548  sizes, strides);
1549 }
1550 
1551 //===----------------------------------------------------------------------===//
1552 // PadOp
1553 //===----------------------------------------------------------------------===//
1554 
1555 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
1556 // supports optional types.
1557 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
1558  Type typeToInfer, Type typeToInferFrom) {}
1559 
1560 ParseResult parseInferType(OpAsmParser &parser,
1561  Optional<OpAsmParser::OperandType> optOperand,
1562  Type &typeToInfer, Type typeToInferFrom) {
1563  if (optOperand)
1564  typeToInfer = typeToInferFrom;
1565  return success();
1566 }
1567 
1568 static LogicalResult verify(PadOp op) {
1569  auto sourceType = op.source().getType().cast<RankedTensorType>();
1570  auto resultType = op.result().getType().cast<RankedTensorType>();
1571  auto expectedType = PadOp::inferResultType(
1572  sourceType, extractFromI64ArrayAttr(op.static_low()),
1573  extractFromI64ArrayAttr(op.static_high()));
1574  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
1575  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
1576  continue;
1577  if (expectedType.isDynamicDim(i))
1578  continue;
1579  return op.emitError("specified type ")
1580  << resultType << " does not match the inferred type "
1581  << expectedType;
1582  }
1583 
1584  auto &region = op.region();
1585  unsigned rank = resultType.getRank();
1586  Block &block = region.front();
1587  if (block.getNumArguments() != rank)
1588  return op.emitError("expected the block to have ") << rank << " arguments";
1589 
1590  // Note: the number and type of yield values are checked in the YieldOp.
1591  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
1592  if (!en.value().isIndex())
1593  return op.emitOpError("expected block argument ")
1594  << (en.index() + 1) << " to be an index";
1595  }
1596 
1597  // Ensure that the region yields an element of the right type.
1598  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
1599  if (yieldOp.value().getType() !=
1600  op.getType().cast<ShapedType>().getElementType())
1601  return op.emitOpError("expected yield type to match shape element type");
1602 
1603  return success();
1604 }
1605 
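/// Illustrative worked example (editor's note, not part of the original
/// source): for a source type tensor<4x?xf32> with staticLow = [1, 1] and
/// staticHigh = [2, kDynamicSize], the inferred result type is
/// tensor<7x?xf32> (4 + 1 + 2 = 7 in the static dimension, dynamic in the
/// other).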
1606 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
1607  ArrayRef<int64_t> staticLow,
1608  ArrayRef<int64_t> staticHigh,
1609  ArrayRef<int64_t> resultShape) {
1610  unsigned rank = sourceType.getRank();
1611  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
1612  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
1613  assert((resultShape.empty() || resultShape.size() == rank) &&
1614  "unexpected resultShape size mismatch");
1615 
1616  SmallVector<int64_t, 4> inferredShape;
1617  for (auto i : llvm::seq<unsigned>(0, rank)) {
1618  if (sourceType.isDynamicDim(i) ||
1619  staticLow[i] == ShapedType::kDynamicSize ||
1620  staticHigh[i] == ShapedType::kDynamicSize) {
1621  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
1622  : resultShape[i]);
1623  } else {
1624  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
1625  assert((resultShape.empty() || size == resultShape[i] ||
1626  resultShape[i] == ShapedType::kDynamicSize) &&
1627  "mismatch between inferred shape and result shape");
1628  inferredShape.push_back(size);
1629  }
1630  }
1631 
1632  return RankedTensorType::get(inferredShape, sourceType.getElementType());
1633 }
1634 
1635 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1636  ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
1637  ValueRange low, ValueRange high, bool nofold,
1638  ArrayRef<NamedAttribute> attrs) {
1639  auto sourceType = source.getType().cast<RankedTensorType>();
1640  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
1641  build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
1642  b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
1643  result.addAttributes(attrs);
1644 }
1645 
1646 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1647  ValueRange low, ValueRange high, bool nofold,
1648  ArrayRef<NamedAttribute> attrs) {
1649  auto sourceType = source.getType().cast<RankedTensorType>();
1650  unsigned rank = sourceType.getRank();
1651  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
1652  build(b, result, source, staticVector, staticVector, low, high, nofold,
1653  attrs);
1654 }
1655 
1656 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
1657  Value source, ArrayRef<OpFoldResult> low,
1658  ArrayRef<OpFoldResult> high, bool nofold,
1659  ArrayRef<NamedAttribute> attrs) {
1660  assert(resultType.isa<RankedTensorType>());
1661  auto sourceType = source.getType().cast<RankedTensorType>();
1662  SmallVector<Value, 4> dynamicLow, dynamicHigh;
1663  SmallVector<int64_t, 4> staticLow, staticHigh;
1664  // staticLow and staticHigh carry the complete padding configuration. Each
1665  // entry of `low`/`high` appends one value to staticLow/staticHigh; if the
1666  // entry is dynamic (i.e. not a constant), dynamicLow/dynamicHigh also grow
1667  // by one value.
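  // For example (editor's note): low = [%v, 2] dispatches to
  // dynamicLow = [%v] and staticLow = [kDynamicSize, 2].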
1668  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
1669  ShapedType::kDynamicSize);
1670  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
1671  ShapedType::kDynamicSize);
1672  if (!resultType) {
1673  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
1674  }
1675  build(b, result, resultType, source, dynamicLow, dynamicHigh,
1676  b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
1677  nofold ? b.getUnitAttr() : UnitAttr());
1678  result.addAttributes(attrs);
1679 }
1680 
1681 namespace {
1682 // Folds tensor.pad when the low and high paddings are all statically zero
1683 // and the `nofold` attribute is not set.
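// Illustrative sketch (editor's example, not from the original source;
// %src and %cst are assumed to be defined earlier):
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<4x4xf32>
//
// becomes:
//
//   %0 = tensor.cast %src : tensor<4x4xf32> to tensor<4x4xf32>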
1684 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
1685  using OpRewritePattern<PadOp>::OpRewritePattern;
1686 
1687  LogicalResult matchAndRewrite(PadOp padTensorOp,
1688  PatternRewriter &rewriter) const override {
1689  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1690  return failure();
1691  if (padTensorOp.nofold())
1692  return failure();
1693  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1694  padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
1695  return success();
1696  }
1697 };
1698 
1699 // Fold CastOp into PadOp when adding static information.
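// Illustrative sketch (editor's example, not from the original source):
//
//   %0 = tensor.cast %arg : tensor<8x16xf32> to tensor<?x?xf32>
//   %1 = tensor.pad %0 low[0, 0] high[1, 1] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
//
// With the cast folded away, the result type can be inferred as
// tensor<9x17xf32>, so the rewrite pads %arg directly and casts the new
// result back to tensor<?x?xf32>.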
1700 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
1701  using OpRewritePattern<PadOp>::OpRewritePattern;
1702 
1703  LogicalResult matchAndRewrite(PadOp padTensorOp,
1704  PatternRewriter &rewriter) const override {
1705  auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
1706  if (!tensor::canFoldIntoConsumerOp(castOp))
1707  return failure();
1708 
1709  auto newResultType = PadOp::inferResultType(
1710  castOp.source().getType().cast<RankedTensorType>(),
1711  extractFromI64ArrayAttr(padTensorOp.static_low()),
1712  extractFromI64ArrayAttr(padTensorOp.static_high()),
1713  padTensorOp.getResultType().getShape());
1714 
1715  if (newResultType == padTensorOp.getResultType()) {
1716  rewriter.updateRootInPlace(padTensorOp, [&]() {
1717  padTensorOp.sourceMutable().assign(castOp.source());
1718  });
1719  } else {
1720  auto newOp = rewriter.create<PadOp>(
1721  padTensorOp->getLoc(), newResultType, padTensorOp.source(),
1722  padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
1723  padTensorOp.static_high(), padTensorOp.nofold());
1724  BlockAndValueMapping mapper;
1725  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
1726 
1727  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1728  padTensorOp, padTensorOp.getResultType(), newOp);
1729  }
1730  return success();
1731  }
1732 };
1733 
1734 // Fold a CastOp that consumes the result of a PadOp back into the PadOp,
1735 // provided the cast adds static information.
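// Illustrative sketch (editor's example, not from the original source):
//
//   %0 = tensor.pad %src low[0, 0] high[1, 1] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<9x17xf32>
//
// becomes a single tensor.pad that produces tensor<9x17xf32> directly.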
1736 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
1737  using OpRewritePattern<PadOp>::OpRewritePattern;
1738 
1739  LogicalResult matchAndRewrite(PadOp padTensorOp,
1740  PatternRewriter &rewriter) const override {
1741  if (!padTensorOp.result().hasOneUse())
1742  return failure();
1743  auto tensorCastOp =
1744  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
1745  if (!tensorCastOp)
1746  return failure();
1747  if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
1748  tensorCastOp.dest().getType()))
1749  return failure();
1750 
1751  auto replacementOp = rewriter.create<PadOp>(
1752  padTensorOp.getLoc(), tensorCastOp.dest().getType(),
1753  padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
1754  padTensorOp.static_low(), padTensorOp.static_high(),
1755  padTensorOp.nofold());
1756  replacementOp.region().takeBody(padTensorOp.region());
1757 
1758  rewriter.replaceOp(padTensorOp, replacementOp.result());
1759  rewriter.replaceOp(tensorCastOp, replacementOp.result());
1760  return success();
1761  }
1762 };
1763 } // namespace
1764 
1765 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1766  MLIRContext *context) {
1767  results
1768  .add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast>(
1769  context);
1770 }
1771 
1772 /// Return the padding value of the PadOp if it is constant. In this context,
1773 /// "constant" means an actual constant or "defined outside of the block".
1774 ///
1775 /// Values are considered constant in three cases:
1776 /// - A ConstantLike value.
1777 /// - A basic block argument from a different block.
1778 /// - A value defined outside of the block.
1779 ///
1780 /// If the padding value is not constant, an empty Value is returned.
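///
/// Illustrative example (editor's sketch, not part of the original source):
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %0 = tensor.pad %src low[1, 1] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<6x6xf32> to tensor<8x8xf32>
/// ```
///
/// Here getConstantPaddingValue() returns %cst: it is ConstantLike and
/// defined outside of the pad body block.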
1781 Value PadOp::getConstantPaddingValue() {
1782  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
1783  if (!yieldOp)
1784  return {};
1785  Value padValue = yieldOp.value();
1786  // Check if yield value is a constant.
1787  if (matchPattern(padValue, m_Constant()))
1788  return padValue;
1789  // Check if yield value is defined inside the PadOp block.
1790  if (padValue.getParentBlock() == &getRegion().front())
1791  return {};
1792  // Else: Yield value defined outside of the PadOp block.
1793  return padValue;
1794 }
1795 
1796 OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
1797  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
1798  !nofold())
1799  return source();
1800  return {};
1801 }
1802 
1803 //===----------------------------------------------------------------------===//
1804 // TableGen'd op method definitions
1805 //===----------------------------------------------------------------------===//
1806 
1807 #define GET_OP_CLASSES
1808 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"