1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallBitVector.h"
22 
23 using namespace mlir;
24 using namespace mlir::tensor;
25 
26 /// Materialize a single constant operation from a given attribute value with
27 /// the desired resultant type.
28 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
29  Attribute value, Type type,
30  Location loc) {
31  if (arith::ConstantOp::isBuildableWith(value, type))
32  return builder.create<arith::ConstantOp>(loc, value, type);
33  if (complex::ConstantOp::isBuildableWith(value, type))
34  return builder.create<complex::ConstantOp>(loc, type,
35  value.cast<ArrayAttr>());
36  return nullptr;
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // CastOp
41 //===----------------------------------------------------------------------===//
42 
43 /// Returns true if `target` is a ranked tensor type that preserves static
44 /// information available in the `source` ranked tensor type.
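/// For illustration (shapes chosen arbitrarily): with `source` =
/// tensor<?x8xf32>, a `target` of tensor<4x8xf32> keeps the static extent 8
/// and may add information, so this returns true; a `target` of
/// tensor<?x?xf32> drops the static extent 8, so this returns false.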
45 bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
46  auto sourceType = source.dyn_cast<RankedTensorType>();
47  auto targetType = target.dyn_cast<RankedTensorType>();
48 
49  // Requires RankedTensorType.
50  if (!sourceType || !targetType)
51  return false;
52 
53  // Requires same elemental type.
54  if (sourceType.getElementType() != targetType.getElementType())
55  return false;
56 
57  // Requires same rank.
58  if (sourceType.getRank() != targetType.getRank())
59  return false;
60 
61  // If cast is towards more static sizes along any dimension, don't fold.
62  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
63  if (!ShapedType::isDynamic(std::get<0>(t)) &&
64  ShapedType::isDynamic(std::get<1>(t)))
65  return false;
66  }
67 
68  return true;
69 }
70 
71 /// Determines whether tensor::CastOp casts to a more dynamic version of the
72 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
73 /// implement canonicalization patterns for ops in different dialects that may
74 /// consume the results of tensor.cast operations. Such foldable tensor.cast
75 /// operations are typically introduced alongside `slice` ops and are
76 /// canonicalized away to preserve the type compatibility of their uses.
77 ///
78 /// Returns true when all conditions are met:
79 /// 1. source and result are ranked tensors with same element type and rank.
80 /// 2. the source type has more static information than the result type.
81 ///
82 /// Example:
83 /// ```mlir
84 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
85 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
86 /// ```
87 ///
88 /// folds into:
89 ///
90 /// ```mlir
91 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
92 /// ```
93 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
94  if (!castOp)
95  return false;
96 
97  // Can fold if the source of cast has at least as much static information as
98  // its results.
99  return preservesStaticInformation(castOp.getType(),
100  castOp.getSource().getType());
101 }
102 
103 /// Determines whether the tensor::CastOp casts to a more static version of the
104 /// source tensor. This is useful to fold into a producing op and implement
105 /// canonicalization patterns with the `tensor.cast` op as the root, while the
106 /// producer may come from a different dialect. Returns true when all conditions are met:
107 /// 1. source and result are ranked tensors with same element type and rank.
108 /// 2. the result type has more static information than the source.
109 ///
110 /// Example:
111 /// ```mlir
112 /// %1 = producer ... : tensor<?x?xf32>
113 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
114 /// ```
115 ///
116 /// can be canonicalized to :
117 ///
118 /// ```mlir
119 /// %2 = producer ... : tensor<8x16xf32>
120 /// ```
121 /// Not all ops might be canonicalizable this way, but for those that can be,
122 /// this method provides a check that it is worth doing the canonicalization.
123 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
124  if (!castOp)
125  return false;
126  return preservesStaticInformation(castOp.getSource().getType(),
127  castOp.getType());
128 }
129 
130 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
131 /// that can be folded.
132 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
133  bool folded = false;
134  for (OpOperand &operand : op->getOpOperands()) {
135  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
136  if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
137  operand.set(castOp.getOperand());
138  folded = true;
139  }
140  }
141  return success(folded);
142 }
143 
144 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
145  if (inputs.size() != 1 || outputs.size() != 1)
146  return false;
147  Type a = inputs.front(), b = outputs.front();
148  auto aT = a.dyn_cast<TensorType>();
149  auto bT = b.dyn_cast<TensorType>();
150  if (!aT || !bT)
151  return false;
152 
153  if (aT.getElementType() != bT.getElementType())
154  return false;
155 
156  return succeeded(verifyCompatibleShape(aT, bT));
157 }
158 
159 /// Compute a TensorType that has the joined shape knowledge of the two
160 /// given TensorTypes. The element types need to match.
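/// As a sketch (types chosen for illustration): joining tensor<?x8xf32> with
/// tensor<4x?xf32> yields tensor<4x8xf32>, while tensor<4x8xf32> and
/// tensor<5x8xf32> have no join and a null type is returned.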
161 static TensorType joinShapes(TensorType one, TensorType two) {
162  assert(one.getElementType() == two.getElementType());
163 
164  if (!one.hasRank())
165  return two;
166  if (!two.hasRank())
167  return one;
168 
169  int64_t rank = one.getRank();
170  if (rank != two.getRank())
171  return {};
172 
173  SmallVector<int64_t, 4> join;
174  join.reserve(rank);
175  for (int64_t i = 0; i < rank; ++i) {
176  if (one.isDynamicDim(i)) {
177  join.push_back(two.getDimSize(i));
178  continue;
179  }
180  if (two.isDynamicDim(i)) {
181  join.push_back(one.getDimSize(i));
182  continue;
183  }
184  if (one.getDimSize(i) != two.getDimSize(i))
185  return {};
186  join.push_back(one.getDimSize(i));
187  }
188  return RankedTensorType::get(join, one.getElementType());
189 }
190 
191 namespace {
192 
193 /// Replaces chains of two tensor.cast operations by a single tensor.cast
194 /// operation if doing so does not remove runtime constraints.
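/// An illustrative fold (SSA names are placeholders):
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x4xf32>
/// %2 = tensor.cast %1 : tensor<?x4xf32> to tensor<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
/// ```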
195 struct ChainedTensorCast : public OpRewritePattern<CastOp> {
196  using OpRewritePattern<CastOp>::OpRewritePattern;
197 
198  LogicalResult matchAndRewrite(CastOp tensorCast,
199  PatternRewriter &rewriter) const final {
200  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
201 
202  if (!tensorCastOperand)
203  return failure();
204 
205  auto sourceType =
206  tensorCastOperand.getOperand().getType().cast<TensorType>();
207  auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
208  auto resultType = tensorCast.getType().cast<TensorType>();
209 
210  // We can remove the intermediate cast if joining all three produces the
211  // same result as just joining the source and result shapes.
212  auto firstJoin =
213  joinShapes(joinShapes(sourceType, intermediateType), resultType);
214 
215  // The join might not exist if the cast sequence would fail at runtime.
216  if (!firstJoin)
217  return failure();
218 
219  // The newJoin always exists if the above join exists; it might just contain
220  // less information. If so, we cannot drop the intermediate cast, as doing
221  // so would remove runtime checks.
222  auto newJoin = joinShapes(sourceType, resultType);
223  if (firstJoin != newJoin)
224  return failure();
225 
226  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
227  tensorCastOperand.getOperand());
228  return success();
229  }
230 };
231 
232 /// Fold tensor.cast into tensor.extract_slice producer.
233 /// Example:
234 /// ```
235 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
236 /// tensor<128x512xf32> to tensor<?x512xf32>
237 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
238 /// ```
239 /// ->
240 /// ```
241 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
242 /// tensor<128x512xf32> to tensor<16x512xf32>
243 /// ```
244 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
245  using OpRewritePattern<CastOp>::OpRewritePattern;
246 
247  LogicalResult matchAndRewrite(CastOp tensorCast,
248  PatternRewriter &rewriter) const final {
249  auto extractOperand =
250  tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
251 
252  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
253  tensorCast.getType().getShape() == tensorCast.getSource()
254  .getType()
255  .cast<RankedTensorType>()
256  .getShape())
257  return failure();
258 
259  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
260  auto dimMask = computeRankReductionMask(
261  extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
262  extractOperand.getType().getShape());
263  size_t dimIndex = 0;
264  for (size_t i = 0, e = sizes.size(); i < e; i++) {
265  if (dimMask && dimMask->count(i))
266  continue;
267  int64_t dim = tensorCast.getType().getShape()[dimIndex++];
268  if (ShapedType::isDynamic(dim))
269  continue;
270  sizes[i] = rewriter.getIndexAttr(dim);
271  }
272 
273  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
274  tensorCast, tensorCast.getType().cast<RankedTensorType>(),
275  extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
276  extractOperand.getMixedStrides());
277  return success();
278  }
279 };
280 
281 } // namespace
282 
283 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
284  MLIRContext *context) {
285  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // DimOp
290 //===----------------------------------------------------------------------===//
291 
292 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
293  int64_t index) {
294  auto loc = result.location;
295  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
296  build(builder, result, source, indexValue);
297 }
298 
299 Optional<int64_t> DimOp::getConstantIndex() {
300  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
301  return constantOp.getValue().cast<IntegerAttr>().getInt();
302  return {};
303 }
304 
305 LogicalResult DimOp::verify() {
306  // Assume unknown index to be in range.
307  Optional<int64_t> index = getConstantIndex();
308  if (!index)
309  return success();
310 
311  // Check that constant index is not knowingly out of range.
312  auto type = getSource().getType();
313  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
314  if (*index >= tensorType.getRank())
315  return emitOpError("index is out of range");
316  } else if (type.isa<UnrankedTensorType>()) {
317  // Assume index to be in range.
318  } else {
319  llvm_unreachable("expected operand with tensor type");
320  }
321  return success();
322 }
323 
324 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
325  // All forms of folding require a known index.
326  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
327  if (!index)
328  return {};
329 
330  // Folding for unranked types (UnrankedTensorType) is not supported.
331  auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
332  if (!tensorType)
333  return {};
334 
335  // Fold if the shape extent along the given index is known.
336  if (!tensorType.isDynamicDim(index.getInt())) {
337  Builder builder(getContext());
338  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
339  }
340 
341  Operation *definingOp = getSource().getDefiningOp();
342 
343  // Fold dim to the operand of tensor.generate.
344  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
345  auto resultType =
346  fromElements.getResult().getType().cast<RankedTensorType>();
347  // The case where the type encodes the size of the dimension is handled
348  // above.
349  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
350 
351  // Find the operand of the fromElements that corresponds to this index.
352  auto dynExtents = fromElements.getDynamicExtents().begin();
353  for (auto dim : resultType.getShape().take_front(index.getInt()))
354  if (ShapedType::isDynamic(dim))
355  dynExtents++;
356 
357  return Value{*dynExtents};
358  }
359 
360  // The size at the given index is now known to be a dynamic size.
361  unsigned unsignedIndex = index.getValue().getZExtValue();
362 
363  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
364  // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
365  // `resolve-shaped-type-result-dims` pass.
366  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
367  sliceOp.isDynamicSize(unsignedIndex)) {
368  return {sliceOp.getDynamicSize(unsignedIndex)};
369  }
370  }
371 
372  // dim(cast) -> dim
373  if (succeeded(foldTensorCast(*this)))
374  return getResult();
375 
376  return {};
377 }
378 
379 namespace {
380 /// Fold dim of a cast into the dim of the source of the tensor cast.
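/// For example (illustrative):
///
/// ```mlir
/// %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
/// %d = tensor.dim %0, %c1 : tensor<?x?xf32>
/// ```
///
/// is rewritten to `%d = tensor.dim %t, %c1 : tensor<4x?xf32>`.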
381 struct DimOfCastOp : public OpRewritePattern<DimOp> {
382  using OpRewritePattern<DimOp>::OpRewritePattern;
383 
384  LogicalResult matchAndRewrite(DimOp dimOp,
385  PatternRewriter &rewriter) const override {
386  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
387  if (!castOp)
388  return failure();
389  Value newSource = castOp.getOperand();
390  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
391  return success();
392  }
393 };
394 } // namespace
395 
396 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
397  MLIRContext *context) {
398  results.add<DimOfCastOp>(context);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // ExtractOp
403 //===----------------------------------------------------------------------===//
404 
405 LogicalResult ExtractOp::verify() {
406  // Verify the # indices match if we have a ranked type.
407  if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
408  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
409  return emitOpError("incorrect number of indices for extract_element");
410 
411  return success();
412 }
413 
414 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
415  // If this is a splat elements attribute, simply return the value. All of the
416  // elements of a splat attribute are the same.
417  if (Attribute tensor = operands.front())
418  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
419  return splatTensor.getSplatValue<Attribute>();
420 
421  // Collect the constant indices into the tensor.
422  SmallVector<uint64_t, 8> indices;
423  for (Attribute indice : llvm::drop_begin(operands, 1)) {
424  if (!indice || !indice.isa<IntegerAttr>())
425  return {};
426  indices.push_back(indice.cast<IntegerAttr>().getInt());
427  }
428 
429  // Fold extract(from_elements(...)).
430  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
431  auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
432  auto rank = tensorType.getRank();
433  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
434  "rank mismatch");
435  int flatIndex = 0;
436  int stride = 1;
437  for (int i = rank - 1; i >= 0; --i) {
 // The stride of dimension i is the product of the trailing dimension sizes.
438  if (i < rank - 1)
439  stride *= tensorType.getDimSize(i + 1);
440  flatIndex += indices[i] * stride;
441  }
442  // Prevent out of bounds accesses. This can happen in invalid code that will
443  // never execute.
444  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
445  flatIndex < 0)
446  return {};
447  return fromElementsOp.getElements()[flatIndex];
448  }
449 
450  // If this is an elements attribute, query the value at the given indices.
451  if (Attribute tensor = operands.front()) {
452  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
453  if (elementsAttr && elementsAttr.isValidIndex(indices))
454  return elementsAttr.getValues<Attribute>()[indices];
455  }
456 
457  return {};
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // FromElementsOp
462 //===----------------------------------------------------------------------===//
463 
464 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
465  Type resultType, ValueRange elements) {
466  result.addOperands(elements);
467  result.addTypes(resultType);
468 }
469 
470 void FromElementsOp::build(OpBuilder &builder, OperationState &result,
471  ValueRange elements) {
472  assert(!elements.empty() && "expected at least one element");
473  Type resultType = RankedTensorType::get(
474  {static_cast<int64_t>(elements.size())}, elements.front().getType());
475  build(builder, result, resultType, elements);
476 }
477 
478 OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
479  if (!llvm::is_contained(operands, nullptr))
480  return DenseElementsAttr::get(getType(), operands);
481  return {};
482 }
483 
484 namespace {
485 
486 // Pushes the index_casts that occur before extractions to after the extract.
487 // This minimizes type conversion in some cases and enables the extract
488 // canonicalizer. This changes:
489 //
490 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
491 // %extract = tensor.extract %cast[%index] : tensor<1xindex>
492 //
493 // to the following:
494 //
495 // %extract = tensor.extract %tensor[%index] : tensor<1xi32>
496 // %cast = arith.index_cast %extract : i32 to index
497 //
500 // Consider expanding this to a template and handling all tensor cast operations.
501 struct ExtractElementFromIndexCast
502  : public OpRewritePattern<tensor::ExtractOp> {
503  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
504 
505  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
506  PatternRewriter &rewriter) const final {
507  Location loc = extract.getLoc();
508  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
509  if (!indexCast)
510  return failure();
511 
512  Type elementTy = getElementTypeOrSelf(indexCast.getIn());
513 
514  auto newExtract = rewriter.create<tensor::ExtractOp>(
515  loc, elementTy, indexCast.getIn(), extract.getIndices());
516 
517  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
518  newExtract);
519 
520  return success();
521  }
522 };
523 
524 } // namespace
525 
526 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
527  MLIRContext *context) {
528  results.add<ExtractElementFromIndexCast>(context);
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // InsertOp
533 //===----------------------------------------------------------------------===//
534 
535 LogicalResult InsertOp::verify() {
536  // Verify the # indices match if we have a ranked type.
537  if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
538  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
539  return emitOpError("incorrect number of indices");
540  return success();
541 }
542 
543 OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
544  Attribute scalar = operands[0];
545  Attribute dest = operands[1];
546  if (scalar && dest)
547  if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
548  if (scalar == splatDest.getSplatValue<Attribute>())
549  return dest;
550  return {};
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // GenerateOp
555 //===----------------------------------------------------------------------===//
556 
557 LogicalResult GenerateOp::reifyResultShapes(
558  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
559  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
560  int idx = 0;
561  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
562  if (getType().isDynamicDim(dim)) {
563  reifiedReturnShapes[0][dim] = getOperand(idx++);
564  } else {
565  reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
566  getLoc(), getType().getDimSize(dim));
567  }
568  }
569  return success();
570 }
571 
572 LogicalResult GenerateOp::verify() {
573  // Ensure that the tensor type has as many dynamic dimensions as are specified
574  // by the operands.
575  RankedTensorType resultTy = getType().cast<RankedTensorType>();
576  if (getNumOperands() != resultTy.getNumDynamicDims())
577  return emitError("must have as many index operands as dynamic extents "
578  "in the result type");
579 
580  return success();
581 }
582 
583 LogicalResult GenerateOp::verifyRegions() {
584  RankedTensorType resultTy = getType().cast<RankedTensorType>();
585  // Ensure that region arguments span the index space.
586  if (!llvm::all_of(getBody().getArgumentTypes(),
587  [](Type ty) { return ty.isIndex(); }))
588  return emitError("all body arguments must be index");
589  if (getBody().getNumArguments() != resultTy.getRank())
590  return emitError("must have one body argument per input dimension");
591 
592  // Ensure that the region yields an element of the right type.
593  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
594 
595  if (yieldOp.getValue().getType() != resultTy.getElementType())
596  return emitOpError(
597  "body must be terminated with a `yield` operation of the tensor "
598  "element type");
599 
600  return success();
601 }
602 
603 void GenerateOp::build(
604  OpBuilder &b, OperationState &result, Type resultTy,
605  ValueRange dynamicExtents,
606  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
607  build(b, result, resultTy, dynamicExtents);
608 
609  // Build and populate body.
610  OpBuilder::InsertionGuard guard(b);
611  Region *bodyRegion = result.regions.front().get();
612  auto rank = resultTy.cast<RankedTensorType>().getRank();
613  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
614  SmallVector<Location, 2> argumentLocs(rank, result.location);
615  Block *bodyBlock =
616  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
617  bodyBuilder(b, result.location, bodyBlock->getArguments());
618 }
619 
620 namespace {
621 
622 /// Canonicalizes tensor.generate operations whose dynamic extent operands are
623 /// constants by folding those constants into the result type. A tensor.cast
624 /// back to the original type is inserted so that the resulting IR remains
625 /// well-typed.
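/// A sketch of the rewrite (constants and types are illustrative):
///
/// ```mlir
/// %c16 = arith.constant 16 : index
/// %0 = tensor.generate %c16, %sz { ... } : tensor<?x?xindex>
/// ```
///
/// becomes
///
/// ```mlir
/// %0 = tensor.generate %sz { ... } : tensor<16x?xindex>
/// %1 = tensor.cast %0 : tensor<16x?xindex> to tensor<?x?xindex>
/// ```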
626 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
627  using OpRewritePattern<GenerateOp>::OpRewritePattern;
628 
629  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
630  PatternRewriter &rewriter) const final {
631  auto resultType =
632  tensorFromElements.getResult().getType().cast<RankedTensorType>();
633 
634  if (resultType.hasStaticShape())
635  return failure();
636 
637  SmallVector<Value, 4> newOperands;
638  SmallVector<int64_t, 4> newShape;
639  auto operandsIt = tensorFromElements.getDynamicExtents().begin();
640 
641  for (int64_t dim : resultType.getShape()) {
642  if (!ShapedType::isDynamic(dim)) {
643  newShape.push_back(dim);
644  continue;
645  }
646  APInt index;
647  if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
648  newShape.push_back(ShapedType::kDynamicSize);
649  newOperands.push_back(*operandsIt++);
650  continue;
651  }
652  newShape.push_back(index.getSExtValue());
653  operandsIt++;
654  }
655 
656  if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
657  return failure();
658 
659  auto loc = tensorFromElements.getLoc();
660  auto newOp = rewriter.create<GenerateOp>(
661  loc, RankedTensorType::get(newShape, resultType.getElementType()),
662  newOperands);
663  rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
664  newOp.getBody().begin());
665  rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
666  newOp);
667  return success();
668  }
669 };
670 
671 /// Canonicalizes the pattern of the form
672 ///
673 /// %tensor = tensor.generate %x {
674 /// ^bb0(%arg0: index):
675 /// <computation>
676 /// yield %1 : index
677 /// } : tensor<?xindex>
678 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
679 ///
680 /// to just <computation> with %arg0 replaced by %c0. We only do this if the
681 /// tensor.generate operation has no side-effects.
682 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
683  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
684 
685  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
686  PatternRewriter &rewriter) const final {
687  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
688  if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
689  return failure();
690 
691  BlockAndValueMapping mapping;
692  Block *body = &tensorFromElements.getBody().front();
693  mapping.map(body->getArguments(), extract.getIndices());
694  for (auto &op : body->without_terminator())
695  rewriter.clone(op, mapping);
696 
697  auto yield = cast<YieldOp>(body->getTerminator());
698 
699  rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
700  return success();
701  }
702 };
703 
704 /// Canonicalizes the pattern of the form
705 ///
706 /// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
707 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
708 ///
709 /// to
710 ///
711 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
712 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
713  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
714 
715  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
716  PatternRewriter &rewriter) const final {
717  auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
718  if (!tensorCast)
719  return failure();
720 
721  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
722  extract, tensorCast.getSource(), extract.getIndices());
723  return success();
724  }
725 };
726 
727 } // namespace
728 
729 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
730  MLIRContext *context) {
731  // TODO: Move extract patterns to tensor::ExtractOp.
732  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
733  StaticTensorGenerate>(context);
734 }
735 
736 //===----------------------------------------------------------------------===//
737 // RankOp
738 //===----------------------------------------------------------------------===//
739 
740 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
741  // Constant fold rank when the rank of the operand is known.
742  auto type = getOperand().getType();
743  auto shapedType = type.dyn_cast<ShapedType>();
744  if (shapedType && shapedType.hasRank())
745  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
746  return IntegerAttr();
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // ReshapeOp
751 //===----------------------------------------------------------------------===//
752 
753 static int64_t getNumElements(ShapedType type) {
754  int64_t numElements = 1;
755  for (auto dim : type.getShape())
756  numElements *= dim;
757  return numElements;
758 }
759 
760 LogicalResult ReshapeOp::verify() {
761  TensorType operandType = getSource().getType().cast<TensorType>();
762  TensorType resultType = getResult().getType().cast<TensorType>();
763 
764  if (operandType.getElementType() != resultType.getElementType())
765  return emitOpError("element types of source and destination tensor "
766  "types should be the same");
767 
768  int64_t shapeSize =
769  getShape().getType().cast<RankedTensorType>().getDimSize(0);
770  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
771  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
772 
773  if (resultRankedType) {
774  if (operandRankedType && resultRankedType.hasStaticShape() &&
775  operandRankedType.hasStaticShape()) {
776  if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
777  return emitOpError("source and destination tensor should have the "
778  "same number of elements");
779  }
780  if (ShapedType::isDynamic(shapeSize))
781  return emitOpError("cannot use shape operand with dynamic length to "
782  "reshape to statically-ranked tensor type");
783  if (shapeSize != resultRankedType.getRank())
784  return emitOpError(
785  "length of shape operand differs from the result's tensor rank");
786  }
787  return success();
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // Reassociative reshape ops
792 //===----------------------------------------------------------------------===//
793 
794 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
795  return getSymbolLessAffineMaps(getReassociationExprs());
796 }
797 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
798  return convertReassociationIndicesToExprs(getContext(),
799  getReassociationIndices());
800 }
801 
802 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
803  return getSymbolLessAffineMaps(getReassociationExprs());
804 }
805 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
806  return convertReassociationIndicesToExprs(getContext(),
807  getReassociationIndices());
808 }
809 
810 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
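/// For instance (illustrative): collapsing tensor<4x?x5xf32> with the
/// reassociation [[0, 1], [2]] produces tensor<?x5xf32>, and collapsing
/// tensor<2x3x5xf32> with the same reassociation produces tensor<6x5xf32>.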
811 static RankedTensorType
812 computeTensorReshapeCollapsedType(RankedTensorType type,
813  ArrayRef<AffineMap> reassociation) {
814  auto shape = type.getShape();
815  SmallVector<int64_t, 4> newShape;
816  newShape.reserve(reassociation.size());
817 
818  // Use the fact that reassociation is valid to simplify the logic: only use
819  // each map's rank.
820  assert(isReassociationValid(reassociation) && "invalid reassociation");
821  unsigned currentDim = 0;
822  for (AffineMap m : reassociation) {
823  unsigned dim = m.getNumResults();
824  auto band = shape.slice(currentDim, dim);
825  int64_t size = 1;
826  if (llvm::is_contained(band, ShapedType::kDynamicSize))
827  size = ShapedType::kDynamicSize;
828  else
829  for (unsigned d = 0; d < dim; ++d)
830  size *= shape[currentDim + d];
831  newShape.push_back(size);
832  currentDim += dim;
833  }
834 
835  return RankedTensorType::get(newShape, type.getElementType());
836 }
837 
838 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
839  ArrayRef<ReassociationIndices> reassociation,
840  ArrayRef<NamedAttribute> attrs) {
841  auto resultType = computeTensorReshapeCollapsedType(
842  src.getType().cast<RankedTensorType>(),
843  getSymbolLessAffineMaps(
844  convertReassociationIndicesToExprs(b.getContext(), reassociation)));
845  build(b, result, resultType, src, attrs);
846  result.addAttribute(getReassociationAttrStrName(),
847  getReassociationIndicesAttribute(b, reassociation));
848 }
849 
850 // Checks if types are the same, but ignoring encoding on ranked tensors.
851 static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
852  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
853  if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
854  return rtp1.getShape() == rtp2.getShape() &&
855  rtp1.getElementType() == rtp2.getElementType();
856  return false;
857  }
858  // Default implementation.
859  return tp1 == tp2;
860 }
861 
862 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
863  TensorReshapeOp, ExpandShapeOp>::value>
864 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
865  RankedTensorType expandedType,
866  RankedTensorType collapsedType) {
867  if (failed(
868  verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
869  return failure();
870 
871  auto maps = op.getReassociationMaps();
872  RankedTensorType expectedType =
873  computeTensorReshapeCollapsedType(expandedType, maps);
874  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
875  return op.emitOpError("expected collapsed type to be ")
876  << expectedType << ", but got " << collapsedType;
877  return success();
878 }
879 
880 LogicalResult ExpandShapeOp::verify() {
881  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
882 }
883 
884 LogicalResult CollapseShapeOp::verify() {
885  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
886 }
887 
888 namespace {
889 /// Reshape of a splat constant can be replaced with a constant of the result
890 /// type.
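/// For example (illustrative):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<2x4xf32>
/// %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
/// ```
///
/// folds to `%0 = arith.constant dense<1.0> : tensor<8xf32>`.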
891 template <typename TensorReshapeOp>
892 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
893  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
894  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
895  PatternRewriter &rewriter) const override {
896  DenseElementsAttr attr;
897  if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
898  return failure();
899  if (!attr || !attr.isSplat())
900  return failure();
901  DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
902  reshapeOp.getResultType(), attr.getRawData());
903  rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
904  return success();
905  }
906 };
907 
908 /// Reshape of a FromElements can be replaced with a FromElements of the result
909 /// type
910 template <typename TensorReshapeOp>
911 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
912  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
913  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
914  PatternRewriter &rewriter) const override {
915  auto fromElements =
916  reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
917  if (!fromElements)
918  return failure();
919 
920  auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
921 
922  if (!shapedTy.hasStaticShape())
923  return failure();
924 
925  rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
926  fromElements.getElements());
927  return success();
928  }
929 };
930 
931 } // namespace
932 
933 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
934  MLIRContext *context) {
937  FoldReshapeWithConstant<ExpandShapeOp>,
938  FoldReshapeWithFromElements<ExpandShapeOp>>(context);
939 }
940 
941 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
942  MLIRContext *context) {
945  FoldReshapeWithConstant<CollapseShapeOp>,
946  FoldReshapeWithFromElements<CollapseShapeOp>>(context);
947 }
948 
949 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
950  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
951 }
952 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
953  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
954 }
955 
956 //===----------------------------------------------------------------------===//
957 // ExtractSliceOp
958 //===----------------------------------------------------------------------===//
959 
960 /// An extract_slice result type can be inferred, when it is not
961 /// rank-reduced, from the source type and the static representation of
962 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
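/// For instance (shapes chosen for illustration), a slice of a
/// tensor<8x16x4xf32> with static sizes [4, 8, 4] infers the result type
/// tensor<4x8x4xf32>; a dynamic size is encoded with the usual sentinel and
/// yields a '?' extent in the inferred type.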
963 RankedTensorType ExtractSliceOp::inferResultType(
964  ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
965  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
966  // An extract_slice op may specify only a leading subset of offset/sizes/
967  // strides, in which case we complete with offset=0, sizes from the source
968  // tensor type, and strides=1.
969  assert(static_cast<int64_t>(staticSizes.size()) ==
970  sourceShapedTensorType.getRank() &&
971  "unexpected staticSizes not equal to rank of source");
972  return RankedTensorType::get(staticSizes,
973  sourceShapedTensorType.getElementType());
974 }
975 
976 RankedTensorType ExtractSliceOp::inferResultType(
977  ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
978  ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
979  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
980  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
981  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
982  ShapedType::kDynamicStrideOrOffset);
983  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
984  ShapedType::kDynamicSize);
985  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
986  ShapedType::kDynamicStrideOrOffset);
987  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
988  staticSizes, staticStrides);
989 }
990 
991 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
992 /// number of sizes), drop as many size-1 dimensions as needed to produce an inferred type
993 /// with the desired rank.
994 ///
995 /// Note that there may be multiple ways to compute this rank-reduced type:
996 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
997 ///
998 /// To disambiguate, this function always drops the first size-1 occurrences, in order.
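/// Concretely (illustrative): with inferred sizes [1, 6, 1] and a desired
/// result rank of 2, the first unit dimension is dropped and the result is
/// tensor<6x1xf32>; with a desired rank of 1, both unit dimensions are
/// dropped and the result is tensor<6xf32>.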
999 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1000  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1001  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1002  ArrayRef<int64_t> strides) {
1003  // Type inferred in the absence of rank-reducing behavior.
1004  auto inferredType =
1005  inferResultType(sourceRankedTensorType, offsets, sizes, strides)
1006  .cast<RankedTensorType>();
1007  int rankDiff = inferredType.getRank() - desiredResultRank;
1008  if (rankDiff > 0) {
1009  auto shape = inferredType.getShape();
1010  llvm::SmallBitVector dimsToProject =
1011  getPositionsOfShapeOne(rankDiff, shape);
1012  SmallVector<int64_t> projectedShape;
1013  // Best effort rank-reducing: drop 1s in order.
1014  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1015  if (!dimsToProject.test(pos))
1016  projectedShape.push_back(shape[pos]);
1017  inferredType =
1018  RankedTensorType::get(projectedShape, inferredType.getElementType());
1019  }
1020  return inferredType;
1021 }
1022 
1023 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1024  unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1025  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1026  ArrayRef<OpFoldResult> strides) {
1027  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1028  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1029  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1030  ShapedType::kDynamicStrideOrOffset);
1031  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1032  ShapedType::kDynamicSize);
1033  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1034  ShapedType::kDynamicStrideOrOffset);
1035  return ExtractSliceOp::inferCanonicalRankReducedResultType(
1036  desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1037  staticStrides);
1038 }
1039 
1040 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom
1041 /// result type. If the type passed is nullptr, it is inferred.
1042 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1043  RankedTensorType resultType, Value source,
1044  ArrayRef<OpFoldResult> offsets,
1045  ArrayRef<OpFoldResult> sizes,
1046  ArrayRef<OpFoldResult> strides,
1047  ArrayRef<NamedAttribute> attrs) {
1048  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1049  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1050  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1051  ShapedType::kDynamicStrideOrOffset);
1052  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1053  ShapedType::kDynamicSize);
1054  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1055  ShapedType::kDynamicStrideOrOffset);
1056  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
1057  // Structuring implementation this way avoids duplication between builders.
1058  if (!resultType) {
1059  resultType =
1060  ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
1061  staticSizes, staticStrides)
1062  .cast<RankedTensorType>();
1063  }
1064  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1065  dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1066  b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1067  result.addAttributes(attrs);
1068 }
1069 
1070 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
1071 /// result type.
1072 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1073  ArrayRef<OpFoldResult> offsets,
1074  ArrayRef<OpFoldResult> sizes,
1075  ArrayRef<OpFoldResult> strides,
1076  ArrayRef<NamedAttribute> attrs) {
1077  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1078 }
1079 
1080 /// Build an ExtractSliceOp with dynamic entries and custom result type. If the
1081 /// type passed is nullptr, it is inferred.
1082 void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
1083  RankedTensorType resultType, Value source,
1084  ValueRange offsets, ValueRange sizes,
1085  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1086  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1087  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1088  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1089  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1090  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1091  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1092  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1093 }
1094 
1095 /// Build an ExtractSliceOp with dynamic entries and inferred result type.
1096 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1097  ValueRange offsets, ValueRange sizes,
1098  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1099  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
1100 }
1101 
1102 template <typename OpTy>
1103 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
1104  OpTy op, Type expectedType) {
1105  auto memrefType = expectedType.cast<ShapedType>();
1106  switch (result) {
1107  case SliceVerificationResult::Success:
1108  return success();
1109  case SliceVerificationResult::RankTooLarge:
1110  return op.emitError("expected rank to be smaller or equal to ")
1111  << "the other rank. ";
1112  case SliceVerificationResult::SizeMismatch:
1113  return op.emitError("expected type to be ")
1114  << expectedType << " or a rank-reduced version. (size mismatch) ";
1115  case SliceVerificationResult::ElemTypeMismatch:
1116  return op.emitError("expected element type to be ")
1117  << memrefType.getElementType();
1118  default:
1119  llvm_unreachable("unexpected extract_slice op verification result");
1120  }
1121 }
1122 
1123 /// Verifier for ExtractSliceOp.
1124 LogicalResult ExtractSliceOp::verify() {
1125  // Verify result type against inferred type.
1126  auto expectedType = ExtractSliceOp::inferResultType(
1127  getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
1128  auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
1129  return produceSliceErrorMsg(result, *this, expectedType);
1130 }
1131 
1132 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
1133  ArrayRef<int64_t> resultShape = getType().getShape();
1134  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1135  llvm::SmallBitVector droppedDims(mixedSizes.size());
1136  unsigned shapePos = 0;
1137  for (const auto &size : enumerate(mixedSizes)) {
1138  Optional<int64_t> sizeVal = getConstantIntValue(size.value());
1139  // If the size is not 1, or if the current matched dimension of the result
1140  // is the same static shape as the size value (which is 1), then the
1141  // dimension is preserved.
1142  if (!sizeVal || *sizeVal != 1 ||
1143  (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
1144  shapePos++;
1145  continue;
1146  }
1147  droppedDims.set(size.index());
1148  }
1149  return droppedDims;
1150 }
1151 
1152 LogicalResult ExtractSliceOp::reifyResultShapes(
1153  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1154  reifiedReturnShapes.resize(1);
1155  reifiedReturnShapes[0].reserve(getType().getRank());
1156  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
1157  llvm::SmallBitVector droppedDims = getDroppedDims();
1158  Location loc = getLoc();
1159  for (const auto &size : enumerate(mixedSizes)) {
1160  if (droppedDims.test(size.index()))
1161  continue;
1162  if (auto attr = size.value().dyn_cast<Attribute>()) {
1163  reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
1164  loc, attr.cast<IntegerAttr>().getInt()));
1165  continue;
1166  }
1167  reifiedReturnShapes[0].push_back(size.value().get<Value>());
1168  }
1169  return success();
1170 }
1171 
1172 namespace {
1173 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
1174 /// This essentially pushes the tensor.cast past its consuming slice when
1175 /// `canFoldIntoConsumerOp` is true.
1176 ///
1177 /// Example:
1178 /// ```
1179 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
1180 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
1181 /// tensor<3x4xf32>
1182 /// ```
1183 /// is rewritten into:
1184 /// ```
1185 /// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to tensor<3x4xf32>
1186 /// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
1187 /// ```
1188 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
1189 public:
1190  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1191 
1192  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
1193  PatternRewriter &rewriter) const override {
1194  // Any constant operand, just return to let the constant folder kick in.
1195  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
1196  return matchPattern(operand, matchConstantIndex());
1197  }))
1198  return failure();
1199 
1200  auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
1201  if (!castOp)
1202  return failure();
1203 
1204  if (!canFoldIntoConsumerOp(castOp))
1205  return failure();
1206 
1207  /// Deduce the type of the result to use for the canonicalized operation.
1208  RankedTensorType resultType =
1209  ExtractSliceOp::inferCanonicalRankReducedResultType(
1210  sliceOp.getType().getRank(), sliceOp.getSourceType(),
1211  sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
1212  sliceOp.getMixedStrides());
1213  Value newSlice = rewriter.create<ExtractSliceOp>(
1214  sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
1215  sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
1216  sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
1217  rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
1218  newSlice);
1219  return success();
1220  }
1221 };
1222 
1223 /// Slice elements from `values` into `outValues`. For each dimension, `counts`
1224 /// gives the number of underlying elements spanned by one step along that dimension.
1225 /// The output values can be used to construct a DenseElementsAttr.
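/// As a worked example (values for illustration): for a 3x4 source holding
/// 0..11 in row-major order, `counts` is [4, 1]; offsets [1, 0], sizes [2, 2]
/// and strides [1, 2] select the elements 4, 6, 8 and 10.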
1226 template <typename IterTy, typename ElemTy>
1227 static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
1228  ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1229  ArrayRef<int64_t> strides,
1230  llvm::SmallVectorImpl<ElemTy> *outValues) {
1231  assert(offsets.size() == sizes.size());
1232  assert(offsets.size() == strides.size());
1233  if (offsets.empty())
1234  return;
1235 
1236  int64_t offset = offsets.front();
1237  int64_t size = sizes.front();
1238  int64_t stride = strides.front();
1239  if (offsets.size() == 1) {
1240  for (int64_t i = 0; i < size; ++i, offset += stride)
1241  outValues->push_back(*(values + offset));
1242 
1243  return;
1244  }
1245 
1246  for (int64_t i = 0; i < size; ++i, offset += stride) {
1247  auto begin = values + offset * counts.front();
1248  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
1249  offsets.drop_front(), sizes.drop_front(),
1250  strides.drop_front(), outValues);
1251  }
1252 }
1253 
1254 /// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
1255 /// operation might introduce more constant data; users can restrict when the
1256 /// fold applies via the control function.
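/// An illustrative fold (subject to the control function):
///
/// ```mlir
/// %cst = arith.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]>
///        : tensor<3x4xi32>
/// %0 = tensor.extract_slice %cst[1, 0] [2, 2] [1, 2]
///        : tensor<3x4xi32> to tensor<2x2xi32>
/// ```
///
/// becomes `%0 = arith.constant dense<[[4, 6], [8, 10]]> : tensor<2x2xi32>`.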
1257 class ConstantOpExtractSliceFolder final
1258  : public OpRewritePattern<ExtractSliceOp> {
1259 public:
1260  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
1261 
1262  ConstantOpExtractSliceFolder(MLIRContext *context,
1263  ControlConstantExtractSliceFusionFn controlFn)
1264  : OpRewritePattern<ExtractSliceOp>(context),
1265  controlFn(std::move(controlFn)) {}
1266 
1267  LogicalResult matchAndRewrite(ExtractSliceOp op,
1268  PatternRewriter &rewriter) const override {
1269  DenseElementsAttr attr;
1270  if (!matchPattern(op.getSource(), m_Constant(&attr)))
1271  return failure();
1272 
1273  // A constant splat is handled by fold().
1274  if (attr.isSplat())
1275  return failure();
1276 
1277  // Dynamic result shape is not supported.
1278  auto sourceType = op.getSource().getType().cast<ShapedType>();
1279  auto resultType = op.getResult().getType().cast<ShapedType>();
1280  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
1281  return failure();
1282 
1283  // Customized control over the folding.
1284  if (!controlFn(op))
1285  return failure();
1286 
1287  int64_t count = sourceType.getNumElements();
1288  if (count == 0)
1289  return failure();
1290 
1291  // Check if there are any dynamic parts, which are not supported.
1292  auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
1293  if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
1294  return failure();
1295  auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
1296  if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
1297  return failure();
1298  auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
1299  if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
1300  return failure();
1301 
1302  // Compute the stride for each dimension.
1303  SmallVector<int64_t> counts;
1304  ArrayRef<int64_t> shape = sourceType.getShape();
1305  counts.reserve(shape.size());
1306  for (int64_t v : shape) {
1307  count = count / v;
1308  counts.push_back(count);
1309  }
1310 
1311  // New attribute constructed by the sliced values.
1312  DenseElementsAttr newAttr;
1313 
1314  if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
1315  SmallVector<APInt> outValues;
1316  outValues.reserve(sourceType.getNumElements());
1317  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
1318  elems.begin(), counts, offsets, sizes, strides, &outValues);
1319  newAttr = DenseElementsAttr::get(resultType, outValues);
1320  } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
1321  SmallVector<APFloat> outValues;
1322  outValues.reserve(sourceType.getNumElements());
1323  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
1324  elems.begin(), counts, offsets, sizes, strides, &outValues);
1325  newAttr = DenseElementsAttr::get(resultType, outValues);
1326  }
1327 
1328  if (newAttr) {
1329  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
1330  return success();
1331  }
1332 
1333  return failure();
1334  }
1335 
1336 private:
1337  /// This additionally controls whether the fold happens or not. Users can
1338  /// impose their heuristics in the function.
1339  ControlConstantExtractSliceFusionFn controlFn;
1340 };
1341 
1342 } // namespace
1343 
1344 void mlir::tensor::populateFoldConstantExtractSlicePatterns(
1345  RewritePatternSet &patterns,
1346  const ControlConstantExtractSliceFusionFn &controlFn) {
1347  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
1348 }
1349 
1350 /// Return the canonical type of the result of an extract_slice op.
1351 struct SliceReturnTypeCanonicalizer {
1352  RankedTensorType operator()(ExtractSliceOp op,
1353  ArrayRef<OpFoldResult> mixedOffsets,
1354  ArrayRef<OpFoldResult> mixedSizes,
1355  ArrayRef<OpFoldResult> mixedStrides) {
1356  return ExtractSliceOp::inferCanonicalRankReducedResultType(
1357  op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
1358  mixedStrides);
1359  }
1360 };
1361 
1362 /// A canonicalizer wrapper to replace ExtractSliceOps.
1363 struct SliceCanonicalizer {
1364  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
1365  ExtractSliceOp newOp) {
1366  Value replacement = newOp.getResult();
1367  if (replacement.getType() != op.getType())
1368  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
1369  replacement);
1370  rewriter.replaceOp(op, replacement);
1371  }
1372 };
1373 
1374 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1375  MLIRContext *context) {
1376  results.add<
1377  OpWithOffsetSizesAndStridesConstantArgumentFolder<
1378  ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
1379  ExtractSliceOpCastFolder>(context);
1380 }
1381 
1382 // Detects a no-op slice: all offsets are 0, all strides are 1, and the sizes match the leading dimensions of `shapedType`.
1383 static LogicalResult
1384 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
1385  ShapedType shapedType) {
1386  OpBuilder b(op.getContext());
1387  for (OpFoldResult ofr : op.getMixedOffsets())
1388  if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
1389  return failure();
1390  // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
1391  // is appropriate.
1392  auto shape = shapedType.getShape();
1393  for (auto it : llvm::zip(op.getMixedSizes(), shape))
1394  if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
1395  return failure();
1396  for (OpFoldResult ofr : op.getMixedStrides())
1397  if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
1398  return failure();
1399  return success();
1400 }
1401 
1402 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice,
1403 /// we can return the InsertSliceOp's source directly.
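/// For example (illustrative):
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dst[0, 1] [2, 3] [1, 1]
///        : tensor<2x3xf32> into tensor<4x5xf32>
/// %1 = tensor.extract_slice %0[0, 1] [2, 3] [1, 1]
///        : tensor<4x5xf32> to tensor<2x3xf32>
/// ```
///
/// Here %1 can be replaced by %src directly.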
1404 // TODO: This only checks the immediate producer; extend to go up the
1405 // insert/extract chain if the slices are disjoint.
1406 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
1407  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
1408 
1409  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1410  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
1411  insertOp.isSameAs(extractOp, isSame))
1412  return insertOp.getSource();
1413 
1414  return {};
1415 }
1416 
1417 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
1418  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
1419  auto resultType = getResult().getType().cast<ShapedType>();
1420  if (resultType.hasStaticShape())
1421  return splat.resizeSplat(resultType);
1422  }
1423  if (getSourceType() == getType() &&
1424  succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1425  return this->getSource();
1426  if (Value slice = foldExtractAfterInsertSlice(*this))
1427  return slice;
1428 
1429  return OpFoldResult();
1430 }
1431 
1432 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
1433  OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
1434  auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
1435  unsigned rank = rankedTensorType.getRank();
1436  auto shape = rankedTensorType.getShape();
1437  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1438  SmallVector<OpFoldResult> sizes;
1439  for (unsigned i = 0, e = rank; i < e; ++i) {
1440  OpFoldResult dim;
1441  if (rankedTensorType.isDynamicDim(i))
1442  dim = b.createOrFold<tensor::DimOp>(
1443  loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
1444  else
1445  dim = b.getIndexAttr(shape[i]);
1446  sizes.push_back(dim);
1447  }
1448  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1449  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
1450  offsets, sizes, strides);
1451 }
1452 
1453 //===----------------------------------------------------------------------===//
1454 // InsertSliceOp
1455 //===----------------------------------------------------------------------===//
1456 
1457 // Build an InsertSliceOp with mixed static and dynamic entries.
1458 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1459  Value dest, ArrayRef<OpFoldResult> offsets,
1460  ArrayRef<OpFoldResult> sizes,
1461  ArrayRef<OpFoldResult> strides,
1462  ArrayRef<NamedAttribute> attrs) {
1463  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1464  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1465  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1466  ShapedType::kDynamicStrideOrOffset);
1467  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1468  ShapedType::kDynamicSize);
1469  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1470  ShapedType::kDynamicStrideOrOffset);
1471  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
1472  dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1473  b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1474  result.addAttributes(attrs);
1475 }
1476 
1477 // Build an InsertSliceOp with dynamic entries.
1478 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
1479  Value dest, ValueRange offsets, ValueRange sizes,
1480  ValueRange strides, ArrayRef<NamedAttribute> attrs) {
1481  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1482  llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1483  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1484  llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1485  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1486  llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1487  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
1488 }
1489 
1490 static SliceVerificationResult
1491 verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
1492  ArrayAttr staticOffsets, ArrayAttr staticSizes,
1493  ArrayAttr staticStrides,
1494  ShapedType *expectedType = nullptr) {
1495  // insert_slice is the inverse of extract_slice, use the same type inference.
1496  auto expected = ExtractSliceOp::inferResultType(
1497  dstType, extractFromI64ArrayAttr(staticOffsets),
1498  extractFromI64ArrayAttr(staticSizes),
1499  extractFromI64ArrayAttr(staticStrides))
1500  .cast<ShapedType>();
1501  if (expectedType)
1502  *expectedType = expected;
1503  return isRankReducedType(expected, srcType);
1504 }
1505 
1506 /// Verifier for InsertSliceOp.
1507 LogicalResult InsertSliceOp::verify() {
1508   ShapedType expectedType;
1509  auto result =
1510  verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
1511  getStaticSizes(), getStaticStrides(), &expectedType);
1512  return produceSliceErrorMsg(result, *this, expectedType);
1513 }
1514 
1515 /// If we have two consecutive InsertSliceOps writing to the same slice, we
1516 /// can mutate the second InsertSliceOp's destination to the first one's.
1517 ///
1518 /// Example:
1519 ///
1520 /// ```mlir
1521 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
1522 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
1523 /// ```
1524 ///
1525 /// folds into:
1526 ///
1527 /// ```mlir
1528 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
1529 /// ```
1530 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
1531  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
1532 
1533  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
1534  if (!prevInsertOp ||
1535  prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
1536  !prevInsertOp.isSameAs(insertOp, isSame))
1537  return failure();
1538 
1539  insertOp.getDestMutable().assign(prevInsertOp.getDest());
1540  return success();
1541 }
1542 
1543 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
1544  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
1545  getSourceType() == getType() &&
1546       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
1547     return this->getSource();
1548   if (succeeded(foldInsertAfterInsertSlice(*this)))
1549     return getResult();
1550  return OpFoldResult();
1551 }
1552 
1553 LogicalResult InsertSliceOp::reifyResultShapes(
1554  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
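  // The result of an insert_slice has the same shape as its destination, so
  // each result dimension is reified as a tensor.dim on the destination.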
1555  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
1556  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1557  reifiedReturnShapes[0][dim] =
1558  builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
1559  }
1560  return success();
1561 }
1562 
1563 namespace {
1564 /// Pattern to rewrite an insert_slice op with constant arguments.
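///
/// For illustration (a sketch), assuming %c0 is defined by
/// `arith.constant 0 : index`:
///
/// ```mlir
/// %r = tensor.insert_slice %src into %dst[%c0, 1] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// ```
///
/// is rewritten so the constant offset becomes a static attribute:
///
/// ```mlir
/// %r = tensor.insert_slice %src into %dst[0, 1] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// ```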
1565 class InsertSliceOpConstantArgumentFolder final
1566  : public OpRewritePattern<InsertSliceOp> {
1567 public:
1568   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1569 
1570  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1571  PatternRewriter &rewriter) const override {
1572  // No constant operand, just return.
1573  if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
1574  return matchPattern(operand, matchConstantIndex());
1575  }))
1576  return failure();
1577 
1578  // At least one of offsets/sizes/strides is a new constant.
1579  // Form the new list of operands and constant attributes from the
1580  // existing.
1581  SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
1582  SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
1583  SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
1584  canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
1585  canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
1586  canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
1587 
1588  // Create the new op in canonical form.
1589  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
1590  insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
1591  mixedOffsets, mixedSizes, mixedStrides);
1592  Value toInsert = insertSliceOp.getSource();
1593  if (sourceType != insertSliceOp.getSourceType())
1594  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
1595  sourceType, toInsert);
1596  rewriter.replaceOpWithNewOp<InsertSliceOp>(
1597  insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
1598  mixedSizes, mixedStrides);
1599  return success();
1600  }
1601 };
1602 
1603 /// Fold tensor_casts with insert_slice operations. If the source or destination
1604 /// tensor is a tensor_cast that removes static type information, the cast is
1605 /// folded into the insert_slice operation. E.g.:
1606 ///
1607 /// ```mlir
1608 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1609 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1610 /// ```
1611 ///
1612 /// folds into:
1613 ///
1614 /// ```mlir
1615 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1616 /// ```
1617 ///
1618 /// Note: When folding a cast on the destination tensor, the result of the
1619 /// insert_slice operation is cast to ensure that the type of the result does
1620 /// not change.
1621 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
1622   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1623 
1624  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1625  PatternRewriter &rewriter) const override {
1626  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
1627  return matchPattern(operand, matchConstantIndex());
1628  }))
1629  return failure();
1630 
1631  auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
1632  auto castOp = v.getDefiningOp<tensor::CastOp>();
1633  if (!castOp || !canFoldIntoConsumerOp(castOp))
1634  return llvm::None;
1635  return castOp.getSource();
1636  };
1637  Optional<Value> sourceCastSource =
1638  getSourceOfCastOp(insertSliceOp.getSource());
1639  Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
1640  if (!sourceCastSource && !destCastSource)
1641  return failure();
1642 
1643  auto src =
1644  (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
1645  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
1646 
1647  auto srcType = src.getType().cast<ShapedType>();
1648  auto dstType = dst.getType().cast<ShapedType>();
1649  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
1650  insertSliceOp.getStaticSizes(),
1651  insertSliceOp.getStaticStrides()) !=
1652         SliceVerificationResult::Success)
1653       return failure();
1654 
1655  Value replacement = rewriter.create<InsertSliceOp>(
1656  insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
1657  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
1658 
1659  if (replacement.getType() != insertSliceOp.getType()) {
1660  replacement = rewriter.create<tensor::CastOp>(
1661  insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
1662  }
1663  rewriter.replaceOp(insertSliceOp, replacement);
1664  return success();
1665  }
1666 };
1667 
1668 /// If additional static type information can be deduced from an insert_slice's
1669 /// size operands, insert an explicit cast of the op's source operand. This
1670 /// enables other canonicalization patterns that are matching for tensor_cast
1671 /// ops such as `ForOpTensorCastFolder` in SCF.
1672 ///
1673 /// Example:
1674 ///
1675 /// ```mlir
1676 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1677 /// : tensor<?x?xf32> into ...
1678 /// ```
1679 ///
1680 /// folds into:
1681 ///
1682 /// ```mlir
1683 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1684 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1685 /// : tensor<64x64xf32> into ...
1686 /// ```
1687 struct InsertSliceOpSourceCastInserter final
1688  : public OpRewritePattern<InsertSliceOp> {
1689   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1690 
1691  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1692  PatternRewriter &rewriter) const override {
1693  RankedTensorType srcType = insertSliceOp.getSourceType();
1694  if (srcType.getRank() != insertSliceOp.getType().getRank())
1695  return failure();
1696  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1697  srcType.getShape().end());
1698  for (int64_t i = 0; i < srcType.getRank(); ++i) {
1699  if (Optional<int64_t> constInt =
1700  getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1701  newSrcShape[i] = *constInt;
1702  }
1703 
1704  RankedTensorType newSrcType =
1705  RankedTensorType::get(newSrcShape, srcType.getElementType());
1706  if (srcType == newSrcType ||
1707  !preservesStaticInformation(srcType, newSrcType) ||
1708  !tensor::CastOp::areCastCompatible(srcType, newSrcType))
1709  return failure();
1710 
1711  // newSrcType is:
1712  // 1) Different from srcType.
1713  // 2) "More static" than srcType.
1714  // 3) Cast-compatible with srcType.
1715  // Insert the cast.
1716  Value cast = rewriter.create<tensor::CastOp>(
1717  insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
1718  rewriter.replaceOpWithNewOp<InsertSliceOp>(
1719  insertSliceOp, cast, insertSliceOp.getDest(),
1720  insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1721  insertSliceOp.getMixedStrides());
1722  return success();
1723  }
1724 };
1725 } // namespace
1726 
1727 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1728  MLIRContext *context) {
1729  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
1730  InsertSliceOpSourceCastInserter>(context);
1731 }
1732 
1733 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
1734                                                              Location loc,
1735  Value tensor,
1736  Value dest) {
1737  auto rankedTensorType = dest.getType().cast<RankedTensorType>();
1738  unsigned rank = rankedTensorType.getRank();
1739  auto shape = rankedTensorType.getShape();
1740  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
1741   SmallVector<OpFoldResult> sizes;
1742   for (unsigned i = 0, e = rank; i < e; ++i) {
1743  OpFoldResult dim;
1744  if (rankedTensorType.isDynamicDim(i))
1745  dim = b.createOrFold<tensor::DimOp>(
1746  loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
1747  else
1748  dim = b.getIndexAttr(shape[i]);
1749  sizes.push_back(dim);
1750  }
1751  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
1752  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
1753  sizes, strides);
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // PadOp
1758 //===----------------------------------------------------------------------===//
1759 
1760 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
1761 // supports optional types.
1762 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
1763  Type typeToInfer, Type typeToInferFrom) {}
1764 
1765 ParseResult parseInferType(OpAsmParser &parser,
1766                            Optional<OpAsmParser::UnresolvedOperand> optOperand,
1767  Type &typeToInfer, Type typeToInferFrom) {
1768  if (optOperand)
1769  typeToInfer = typeToInferFrom;
1770  return success();
1771 }
1772 
1773 LogicalResult PadOp::verify() {
1774   auto sourceType = getSource().getType().cast<RankedTensorType>();
1775  auto resultType = getResult().getType().cast<RankedTensorType>();
1776  auto expectedType = PadOp::inferResultType(
1777  sourceType, extractFromI64ArrayAttr(getStaticLow()),
1778  extractFromI64ArrayAttr(getStaticHigh()));
1779  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
1780  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
1781  continue;
1782  if (expectedType.isDynamicDim(i))
1783  continue;
1784  return emitError("specified type ")
1785  << resultType << " does not match the inferred type "
1786  << expectedType;
1787  }
1788 
1789  return success();
1790 }
1791 
1792 LogicalResult PadOp::verifyRegions() {
1793  auto &region = getRegion();
1794  unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
1795  Block &block = region.front();
1796  if (block.getNumArguments() != rank)
1797  return emitError("expected the block to have ") << rank << " arguments";
1798 
1799  // Note: the number and type of yield values are checked in the YieldOp.
1800  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
1801  if (!en.value().isIndex())
1802  return emitOpError("expected block argument ")
1803  << (en.index() + 1) << " to be an index";
1804  }
1805 
1806  // Ensure that the region yields an element of the right type.
1807  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
1808  if (yieldOp.getValue().getType() !=
1809  getType().cast<ShapedType>().getElementType())
1810  return emitOpError("expected yield type to match shape element type");
1811 
1812  return success();
1813 }
1814 
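/// For illustration (a sketch): with a source of type tensor<4x?xf32>,
/// staticLow = [1, 2], and staticHigh = [3, 0], the inferred type is
/// tensor<8x?xf32>: dimension 0 becomes 4 + 1 + 3 = 8, and dimension 1 stays
/// dynamic because the source dimension is dynamic.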
1815 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
1816  ArrayRef<int64_t> staticLow,
1817  ArrayRef<int64_t> staticHigh,
1818  ArrayRef<int64_t> resultShape) {
1819  unsigned rank = sourceType.getRank();
1820  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
1821  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
1822  assert((resultShape.empty() || resultShape.size() == rank) &&
1823  "unexpected resultShape size mismatch");
1824 
1825  SmallVector<int64_t, 4> inferredShape;
1826  for (auto i : llvm::seq<unsigned>(0, rank)) {
1827  if (sourceType.isDynamicDim(i) ||
1828  staticLow[i] == ShapedType::kDynamicSize ||
1829  staticHigh[i] == ShapedType::kDynamicSize) {
1830  inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
1831  : resultShape[i]);
1832  } else {
1833  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
1834  assert((resultShape.empty() || size == resultShape[i] ||
1835  resultShape[i] == ShapedType::kDynamicSize) &&
1836  "mismatch between inferred shape and result shape");
1837  inferredShape.push_back(size);
1838  }
1839  }
1840 
1841  return RankedTensorType::get(inferredShape, sourceType.getElementType());
1842 }
1843 
1844 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1845  ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
1846  ValueRange low, ValueRange high, bool nofold,
1847  ArrayRef<NamedAttribute> attrs) {
1848  auto sourceType = source.getType().cast<RankedTensorType>();
1849  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
1850  build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
1851  b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
1852  result.addAttributes(attrs);
1853 }
1854 
1855 void PadOp::build(OpBuilder &b, OperationState &result, Value source,
1856  ValueRange low, ValueRange high, bool nofold,
1857  ArrayRef<NamedAttribute> attrs) {
1858  auto sourceType = source.getType().cast<RankedTensorType>();
1859  unsigned rank = sourceType.getRank();
1860  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
1861  build(b, result, source, staticVector, staticVector, low, high, nofold,
1862  attrs);
1863 }
1864 
1865 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
1866  Value source, ArrayRef<OpFoldResult> low,
1867  ArrayRef<OpFoldResult> high, bool nofold,
1868  ArrayRef<NamedAttribute> attrs) {
1869  assert(resultType.isa<RankedTensorType>());
1870  auto sourceType = source.getType().cast<RankedTensorType>();
1871  SmallVector<Value, 4> dynamicLow, dynamicHigh;
1872  SmallVector<int64_t, 4> staticLow, staticHigh;
1873   // staticLow and staticHigh carry the full padding configuration. Each
1874   // entry grows staticLow and staticHigh by one value. If an entry is
1875   // dynamic (i.e. not a constant), dynamicLow and dynamicHigh also grow by
1876   // one value.
1877  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
1878  ShapedType::kDynamicSize);
1879  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
1880  ShapedType::kDynamicSize);
1881  if (!resultType) {
1882  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
1883  }
1884  build(b, result, resultType, source, dynamicLow, dynamicHigh,
1885  b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
1886  nofold ? b.getUnitAttr() : UnitAttr());
1887  result.addAttributes(attrs);
1888 }
1889 
1890 llvm::SmallBitVector PadOp::getPaddedDims() {
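  // A dimension counts as padded unless both its low and high padding widths
  // are statically known to be zero.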
1891  llvm::SmallBitVector paddedDims(getSourceType().getRank());
1892  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
1893  for (const auto &en : enumerate(paddingWidths))
1894  if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
1895  paddedDims.set(en.index());
1896  };
1897  extractPaddedDims(getMixedLowPad());
1898  extractPaddedDims(getMixedHighPad());
1899  return paddedDims;
1900 }
1901 
1902 namespace {
1903 // Folds tensor.pad when the padding is all static zeros and the nofold
1904 // attribute doesn't request otherwise.
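//
// For illustration (a sketch): a pad with all-zero static padding such as
//
//   %p = tensor.pad %t low[0, 0] high[0, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<4x4xf32>
//
// is rewritten to `tensor.cast %t : tensor<4x4xf32> to tensor<4x4xf32>`,
// provided the nofold attribute is not set.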
1905 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
1906   using OpRewritePattern<PadOp>::OpRewritePattern;
1907 
1908  LogicalResult matchAndRewrite(PadOp padTensorOp,
1909  PatternRewriter &rewriter) const override {
1910  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
1911  return failure();
1912  if (padTensorOp.getNofold())
1913  return failure();
1914  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1915  padTensorOp, padTensorOp.getResult().getType(),
1916  padTensorOp.getSource());
1917  return success();
1918  }
1919 };
1920 
1921 // Fold CastOp into PadOp when adding static information.
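//
// For illustration (a sketch): if the pad's source is produced by
// `%c = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>`, the cast can be
// folded away so the pad reads from %0 directly; when doing so changes the
// inferred result type, a cast back to the original result type is inserted.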
1922 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
1923   using OpRewritePattern<PadOp>::OpRewritePattern;
1924 
1925  LogicalResult matchAndRewrite(PadOp padTensorOp,
1926  PatternRewriter &rewriter) const override {
1927  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
1928  if (!tensor::canFoldIntoConsumerOp(castOp))
1929  return failure();
1930 
1931  auto newResultType = PadOp::inferResultType(
1932  castOp.getSource().getType().cast<RankedTensorType>(),
1933  extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
1934  extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
1935  padTensorOp.getResultType().getShape());
1936 
1937  if (newResultType == padTensorOp.getResultType()) {
1938  rewriter.updateRootInPlace(padTensorOp, [&]() {
1939  padTensorOp.getSourceMutable().assign(castOp.getSource());
1940  });
1941  } else {
1942  auto newOp = rewriter.create<PadOp>(
1943  padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
1944  padTensorOp.getLow(), padTensorOp.getHigh(),
1945  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
1946  padTensorOp.getNofold());
1947  BlockAndValueMapping mapper;
1948  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
1949 
1950  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1951  padTensorOp, padTensorOp.getResultType(), newOp);
1952  }
1953  return success();
1954  }
1955 };
1956 
1957 // Fold CastOp using the result of PadOp back into the latter if it adds
1958 // static information.
1959 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
1960   using OpRewritePattern<PadOp>::OpRewritePattern;
1961 
1962  LogicalResult matchAndRewrite(PadOp padTensorOp,
1963  PatternRewriter &rewriter) const override {
1964  if (!padTensorOp.getResult().hasOneUse())
1965  return failure();
1966  auto tensorCastOp =
1967  dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
1968  if (!tensorCastOp)
1969  return failure();
1970  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
1971  tensorCastOp.getDest().getType()))
1972  return failure();
1973 
1974  auto replacementOp = rewriter.create<PadOp>(
1975  padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
1976  padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
1977  padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
1978  padTensorOp.getNofold());
1979  replacementOp.getRegion().takeBody(padTensorOp.getRegion());
1980 
1981  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
1982  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
1983  return success();
1984  }
1985 };
1986 
1987 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
1988 /// different dimensions. The pattern applies if the following preconditions
1989 /// hold:
1990 /// 1) the tensor::ExtractSliceOps are not rank-reducing,
1991 /// 2) the tensor::ExtractSliceOps have only unit-strides,
1992 /// 3) the tensor::PadOps perform only high-padding,
1993 /// 4) the tensor::PadOps have the same constant padding value,
1994 /// 5) the tensor::PadOps do not have common padding dimensions,
1995 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
1996 /// zero-offset for every dimension,
1997 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
1998 /// padded source dimensions.
1999 ///
2000 /// Example:
2001 ///
2002 /// ```mlir
2003 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
2004 /// : tensor<64x64xf32> to tensor<?x64xf32>
2005 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
2006 /// } : tensor<?x64xf32> to tensor<8x64xf32>
2007 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
2008 /// : tensor<8x64xf32> to tensor<8x?xf32>
2009 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
2010 /// } : tensor<8x?xf32> to tensor<8x4xf32>
2011 /// ```
2012 ///
2013 /// folds into:
2014 ///
2015 /// ```mlir
2016 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
2017 /// : tensor<64x64xf32> to tensor<?x?xf32>
2018 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
2019 /// } : tensor<?x?xf32> to tensor<8x4xf32>
2020 /// ```
2021 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
2022   using OpRewritePattern<PadOp>::OpRewritePattern;
2023 
2024  LogicalResult matchAndRewrite(PadOp padOp,
2025  PatternRewriter &rewriter) const override {
2026  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
2027  if (!innerSliceOp)
2028  return failure();
2029  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
2030  if (!outerPadOp || outerPadOp.getNofold())
2031  return failure();
2032  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
2033  if (!outerSliceOp)
2034  return failure();
2035 
2036  // 1) Fail if the chain is rank-reducing.
2037  int64_t rank = padOp.getSourceType().getRank();
2038  if (outerSliceOp.getSourceType().getRank() != rank) {
2039  return rewriter.notifyMatchFailure(padOp,
2040  "cannot fold rank-reducing chain");
2041  }
2042 
2043  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2044  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
2045  return rewriter.notifyMatchFailure(
2046  padOp, "cannot fold non-unit stride ExtractSliceOps");
2047  }
2048 
2049  // 3) Fail if the tensor::PadOps have non-zero low padding.
2050  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
2051  return rewriter.notifyMatchFailure(padOp,
2052  "cannot fold PadOps with low padding");
2053  }
2054 
2055  // 4) Fail if the tensor::PadOps padding values do not match.
2056  Attribute innerAttr, outerAttr;
2057  Value innerValue = padOp.getConstantPaddingValue();
2058  Value outerValue = outerPadOp.getConstantPaddingValue();
2059  if (!innerValue || !outerValue ||
2060  !matchPattern(innerValue, m_Constant(&innerAttr)) ||
2061  !matchPattern(outerValue, m_Constant(&outerAttr)) ||
2062  innerAttr != outerAttr) {
2063  return rewriter.notifyMatchFailure(
2064  padOp, "cannot fold PadOps with different padding values");
2065  }
2066 
2067  // 5) Fail if a dimension is padded by both tensor::PadOps.
2068  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
2069  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
2070  if (innerDims.anyCommon(outerDims)) {
2071  return rewriter.notifyMatchFailure(
2072  padOp, "cannot fold PadOps with common padding dimensions");
2073  }
2074 
2075  // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2076  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2077 // for every dimension, and use the offset of the other pair. Fail if no
2078  // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2079  // exists.
2080  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
2081  for (auto &en : enumerate(newOffsets)) {
2082  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
2083  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
2084  if (!innerDims.test(en.index()) &&
2085  (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
2086  en.value() = outerOffset;
2087  continue;
2088  }
2089  if (!outerDims.test(en.index()) &&
2090  (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
2091  en.value() = innerOffset;
2092  continue;
2093  }
2094  return rewriter.notifyMatchFailure(
2095  padOp, "cannot find zero-offset and zero-padding pair");
2096  }
2097 
2098  // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
2099  // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
2100  // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
2101  // does not match the size of the padded dimension. Otherwise, take the size
2102  // of the inner tensor::ExtractSliceOp.
2103  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
2104  for (auto &en : enumerate(newSizes)) {
2105  if (!outerDims.test(en.index()))
2106  continue;
2107  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
2108  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
2109  assert(!ShapedType::isDynamic(sourceSize) &&
2110  "expected padded dimension to have a static size");
2111  if (getConstantIntValue(sliceSize) != sourceSize) {
2112  return rewriter.notifyMatchFailure(
2113  padOp, "cannot fold since the inner ExtractSliceOp size does not "
2114  "match the size of the outer padding");
2115  }
2116  en.value() = outerSliceOp.getMixedSizes()[en.index()];
2117  }
2118 
2119  // Combine the high paddings of the two tensor::PadOps.
2120  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
2121  for (auto &en : enumerate(newHighPad)) {
2122  if (innerDims.test(en.index()))
2123  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
2124  if (outerDims.test(en.index()))
2125  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
2126  }
2127 
2128  // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
2129  // two paddings in one step.
2130  auto newSliceOp = rewriter.create<ExtractSliceOp>(
2131  padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
2132  innerSliceOp.getMixedStrides());
2133  auto newPadOp = rewriter.create<PadOp>(
2134  padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
2135  padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
2136  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
2137  newPadOp.getRegion().begin());
2138  rewriter.replaceOp(padOp, newPadOp.getResult());
2139  return success();
2140  }
2141 };
2142 
2143 } // namespace
2144 
2145 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2146  MLIRContext *context) {
2147  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2148  FoldOrthogonalPaddings>(context);
2149 }
2150 
2151 /// Return the padding value of the PadOp if it is constant. In this context,
2152 /// "constant" means an actual constant or "defined outside of the block".
2153 ///
2154 /// Values are considered constant in three cases:
2155 /// - A ConstantLike value.
2156 /// - A basic block argument from a different block.
2157 /// - A value defined outside of the block.
2158 ///
2159 /// If the padding value is not constant, an empty Value is returned.
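///
/// For illustration (a sketch):
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %t low[1, 1] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<2x2xf32> to tensor<4x4xf32>
/// ```
///
/// Here the yielded %cst matches m_Constant and is returned; a value defined
/// inside the pad block would not be considered constant.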
2160 Value PadOp::getConstantPaddingValue() {
2161  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
2162  if (!yieldOp)
2163  return {};
2164  Value padValue = yieldOp.getValue();
2165  // Check if yield value is a constant.
2166  if (matchPattern(padValue, m_Constant()))
2167  return padValue;
2168  // Check if yield value is defined inside the PadOp block.
2169  if (padValue.getParentBlock() == &getRegion().front())
2170  return {};
2171  // Else: Yield value defined outside of the PadOp block.
2172  return padValue;
2173 }
2174 
2175 OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
2176  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
2177  !getNofold())
2178  return getSource();
2179  return {};
2180 }
2181 
2182 //===----------------------------------------------------------------------===//
2183 // SplatOp
2184 //===----------------------------------------------------------------------===//
2185 
2186 OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2187  auto constOperand = operands.front();
2188  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
2189  return {};
2190 
2191   // SplatElementsAttr::get treats a single value for the second arg as a splat.
2192  return SplatElementsAttr::get(getType(), {constOperand});
2193 }
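// For illustration: a tensor.splat whose operand folds to the constant
// 1.0 : f32 and whose result type is tensor<4xf32> folds above to the splat
// attribute dense<1.0> : tensor<4xf32>.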
2194 
2195 //===----------------------------------------------------------------------===//
2196 // TableGen'd op method definitions
2197 //===----------------------------------------------------------------------===//
2198 
2199 #define GET_OP_CLASSES
2200 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"