ExpandStridedMetadata.cpp
1 //===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// The pass expands memref operations that modify the metadata of a memref
10 /// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
11 /// In particular, this pass rewrites these operations into an explicit
12 /// sequence of operations that models their effect on each piece of
13 /// metadata. The pass uses affine constructs to materialize these effects.
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
20 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include <optional>
28 
29 namespace mlir {
30 namespace memref {
31 #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
32 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
33 } // namespace memref
34 } // namespace mlir
35 
36 using namespace mlir;
37 using namespace mlir::affine;
38 
39 namespace {
40 
41 struct StridedMetadata {
42  Value basePtr;
43  OpFoldResult offset;
44  SmallVector<OpFoldResult> sizes;
45  SmallVector<OpFoldResult> strides;
46 };
47 
48 /// From `subview(memref, subOffset, subSizes, subStrides)` compute
49 ///
50 /// \verbatim
51 /// baseBuffer, baseOffset, baseSizes, baseStrides =
52 /// extract_strided_metadata(memref)
53 /// strides#i = baseStrides#i * subStrides#i
54 /// offset = baseOffset + sum(subOffset#i * baseStrides#i)
55 /// sizes = subSizes
56 /// \endverbatim
57 ///
58 /// and return {baseBuffer, offset, sizes, strides}
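///
/// For instance, with illustrative values baseOffset = 3,
/// baseStrides = [10, 1], subOffsets = [4, 5], and subStrides = [2, 3]:
///   strides = [10 * 2, 1 * 3] = [20, 3]
///   offset  = 3 + 4 * 10 + 5 * 1 = 48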
59 static FailureOr<StridedMetadata>
60 resolveSubviewStridedMetadata(RewriterBase &rewriter,
61  memref::SubViewOp subview) {
62  // Build a plain extract_strided_metadata(memref) from subview(memref).
63  Location origLoc = subview.getLoc();
64  Value source = subview.getSource();
65  auto sourceType = cast<MemRefType>(source.getType());
66  unsigned sourceRank = sourceType.getRank();
67 
68  auto newExtractStridedMetadata =
69  rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
70 
71  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
72 #ifndef NDEBUG
73  auto [resultStrides, resultOffset] = getStridesAndOffset(subview.getType());
74 #endif // NDEBUG
75 
76  // Compute the new strides and offset from the base strides and offset:
77  // newStride#i = baseStride#i * subStride#i
78  // offset = baseOffset + sum(subOffsets#i * baseStrides#i)
79  SmallVector<OpFoldResult> strides;
80  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
81  auto origStrides = newExtractStridedMetadata.getStrides();
82 
83  // Hold the affine symbols and values for the computation of the offset.
84  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
85  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
86 
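 // Symbol/value layout: index 0 holds the base offset; for each source
 // dimension i, index 1 + 2*i holds subOffsets[i] and index 2 + 2*i holds
 // the corresponding original (base) stride.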
87  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
88  AffineExpr expr = symbols.front();
89  values[0] = ShapedType::isDynamic(sourceOffset)
90  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
91  : rewriter.getIndexAttr(sourceOffset);
92  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
93 
94  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
95  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
96  for (unsigned i = 0; i < sourceRank; ++i) {
97  // Compute the stride.
98  OpFoldResult origStride =
99  ShapedType::isDynamic(sourceStrides[i])
100  ? origStrides[i]
101  : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
102  strides.push_back(makeComposedFoldedAffineApply(
103  rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
104 
105  // Build up the computation of the offset.
106  unsigned baseIdxForDim = 1 + 2 * i;
107  unsigned subOffsetForDim = baseIdxForDim;
108  unsigned origStrideForDim = baseIdxForDim + 1;
109  expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
110  values[subOffsetForDim] = subOffsets[i];
111  values[origStrideForDim] = origStride;
112  }
113 
114  // Compute the offset.
115  OpFoldResult finalOffset =
116  makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
117 #ifndef NDEBUG
118  // Assert that the computed offset matches the offset of the result type of
119  // the subview op (if both are static).
120  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
121  if (computedOffset && !ShapedType::isDynamic(resultOffset))
122  assert(*computedOffset == resultOffset &&
123  "mismatch between computed offset and result type offset");
124 #endif // NDEBUG
125 
126  // The final result is <baseBuffer, offset, sizes, strides>.
127  // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
128  // the values.
129  auto subType = cast<MemRefType>(subview.getType());
130  unsigned subRank = subType.getRank();
131 
132  // The sizes of the final type are defined directly by the input sizes of
133  // the subview.
134  // Moreover, subviews can drop some dimensions, so some strides and sizes
135  // may not end up in the final <base, offset, sizes, strides> value that we
136  // are replacing.
137  // Do the filtering here.
138  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
139  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
140 
141  SmallVector<OpFoldResult> finalSizes;
142  finalSizes.reserve(subRank);
143 
144  SmallVector<OpFoldResult> finalStrides;
145  finalStrides.reserve(subRank);
146 
147 #ifndef NDEBUG
148  // Iteration variable for result dimensions of the subview op.
149  int64_t j = 0;
150 #endif // NDEBUG
151  for (unsigned i = 0; i < sourceRank; ++i) {
152  if (droppedDims.test(i))
153  continue;
154 
155  finalSizes.push_back(subSizes[i]);
156  finalStrides.push_back(strides[i]);
157 #ifndef NDEBUG
158  // Assert that the computed stride matches the stride of the result type of
159  // the subview op (if both are static).
160  std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
161  if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
162  assert(*computedStride == resultStrides[j] &&
163  "mismatch between computed stride and result type stride");
164  ++j;
165 #endif // NDEBUG
166  }
167  assert(finalSizes.size() == subRank &&
168  "Should have populated all the values at this point");
169  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
170  finalSizes, finalStrides};
171 }
172 
173 /// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
174 /// With
175 ///
176 /// \verbatim
177 /// baseBuffer, baseOffset, baseSizes, baseStrides =
178 /// extract_strided_metadata(memref)
179 /// strides#i = baseStrides#i * subStrides#i
180 /// offset = baseOffset + sum(subOffset#i * baseStrides#i)
181 /// sizes = subSizes
182 /// dst = reinterpret_cast baseBuffer, offset, sizes, strides
183 /// \endverbatim
184 ///
185 /// In other words, get rid of the subview in that expression and materialize
186 /// its effects on the offset, the sizes, and the strides using affine.apply.
187 struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
188 public:
189  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
190 
191  LogicalResult matchAndRewrite(memref::SubViewOp subview,
192  PatternRewriter &rewriter) const override {
193  FailureOr<StridedMetadata> stridedMetadata =
194  resolveSubviewStridedMetadata(rewriter, subview);
195  if (failed(stridedMetadata)) {
196  return rewriter.notifyMatchFailure(subview,
197  "failed to resolve subview metadata");
198  }
199 
200  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
201  subview, subview.getType(), stridedMetadata->basePtr,
202  stridedMetadata->offset, stridedMetadata->sizes,
203  stridedMetadata->strides);
204  return success();
205  }
206 };
207 
208 /// Pattern to replace `extract_strided_metadata(subview)`
209 /// With
210 ///
211 /// \verbatim
212 /// baseBuffer, baseOffset, baseSizes, baseStrides =
213 /// extract_strided_metadata(memref)
214 /// strides#i = baseStrides#i * subStrides#i
215 /// offset = baseOffset + sum(subOffset#i * baseStrides#i)
216 /// sizes = subSizes
217 /// \endverbatim
218 ///
219 /// with `baseBuffer`, `offset`, `sizes` and `strides` being
220 /// the replacements for the original `extract_strided_metadata`.
221 struct ExtractStridedMetadataOpSubviewFolder
222  : OpRewritePattern<memref::ExtractStridedMetadataOp> {
223  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
224 
225  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
226  PatternRewriter &rewriter) const override {
227  auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
228  if (!subviewOp)
229  return failure();
230 
231  FailureOr<StridedMetadata> stridedMetadata =
232  resolveSubviewStridedMetadata(rewriter, subviewOp);
233  if (failed(stridedMetadata)) {
234  return rewriter.notifyMatchFailure(
235  op, "failed to resolve metadata in terms of source subview op");
236  }
237  Location loc = subviewOp.getLoc();
238  SmallVector<Value> results;
239  results.reserve(subviewOp.getType().getRank() * 2 + 2);
240  results.push_back(stridedMetadata->basePtr);
241  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
242  stridedMetadata->offset));
243  results.append(
244  getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
245  results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
246  stridedMetadata->strides));
247  rewriter.replaceOp(op, results);
248 
249  return success();
250  }
251 };
252 
253 /// Compute the expanded sizes of the given \p expandShape for the
254 /// \p groupId-th reassociation group.
256 /// \p origSizes holds the sizes of the source shape as values.
256 /// This is used to compute the new sizes in cases of dynamic shapes.
257 ///
258 /// sizes#i =
259 /// baseSizes#groupId / product(expandShapeSizes#j,
260 /// for j in group excluding reassIdx#i)
261 /// Where reassIdx#i is the reassociation index at index i in \p groupId.
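///
/// For instance, for an illustrative group whose result sizes are [2, ?, 4]
/// and whose original (collapsed) size is %n:
///   sizes = [2, %n floordiv 8, 4]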
262 ///
263 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
264 ///
265 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
266 /// this is not possible because this function uses the Affine dialect and the
267 /// MemRef dialect cannot depend on the Affine dialect.
268 static SmallVector<OpFoldResult>
269 getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
270  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
271  SmallVector<int64_t, 2> reassocGroup =
272  expandShape.getReassociationIndices()[groupId];
273  assert(!reassocGroup.empty() &&
274  "Reassociation group should have at least one dimension");
275 
276  unsigned groupSize = reassocGroup.size();
277  SmallVector<OpFoldResult> expandedSizes(groupSize);
278 
279  uint64_t productOfAllStaticSizes = 1;
280  std::optional<unsigned> dynSizeIdx;
281  MemRefType expandShapeType = expandShape.getResultType();
282 
283  // Fill up all the statically known sizes.
284  for (unsigned i = 0; i < groupSize; ++i) {
285  uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
286  if (ShapedType::isDynamic(dimSize)) {
287  assert(!dynSizeIdx && "There must be at most one dynamic size per group");
288  dynSizeIdx = i;
289  continue;
290  }
291  productOfAllStaticSizes *= dimSize;
292  expandedSizes[i] = builder.getIndexAttr(dimSize);
293  }
294 
295  // Compute the dynamic size using the original size and all the other known
296  // static sizes:
297  // expandSize = origSize / productOfAllStaticSizes.
298  if (dynSizeIdx) {
299  AffineExpr s0 = builder.getAffineSymbolExpr(0);
300  expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
301  builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
302  origSizes[groupId]);
303  }
304 
305  return expandedSizes;
306 }
307 
308 /// Compute the expanded strides of the given \p expandShape for the
309 /// \p groupId-th reassociation group.
310 /// \p origStrides and \p origSizes hold respectively the strides and sizes
311 /// of the source shape as values.
312 /// This is used to compute the strides in cases of dynamic shapes and/or
313 /// dynamic stride for this reassociation group.
314 ///
315 /// strides#i =
316 /// origStrides#reassDim * product(expandShapeSizes#j, for j in
317 /// reassIdx#i+1..reassIdx#i+group.size-1)
318 ///
319 /// Where reassIdx#i is the reassociation index at index i in \p groupId
320 /// and expandShapeSizes#j is either:
321 /// - The constant size at dimension j, derived directly from the result type of
322 /// the expand_shape op, or
323 /// - An affine expression: baseSizes#reassDim / product of all constant sizes
324 /// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325 /// element.)
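///
/// For instance, for an illustrative group whose result sizes are [2, ?, 4],
/// with original (collapsed) size %n and original stride %s:
///   strides = [((%n * 4) floordiv 8) * %s, 4 * %s, %s]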
326 ///
327 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
328 ///
329 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
330 /// this is not possible because this function uses the Affine dialect and the
331 /// MemRef dialect cannot depend on the Affine dialect.
332 SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
333  OpBuilder &builder,
334  ArrayRef<OpFoldResult> origSizes,
335  ArrayRef<OpFoldResult> origStrides,
336  unsigned groupId) {
337  SmallVector<int64_t, 2> reassocGroup =
338  expandShape.getReassociationIndices()[groupId];
339  assert(!reassocGroup.empty() &&
340  "Reassociation group should have at least one dimension");
341 
342  unsigned groupSize = reassocGroup.size();
343  MemRefType expandShapeType = expandShape.getResultType();
344 
345  std::optional<int64_t> dynSizeIdx;
346 
347  // Fill up the expanded strides, with the information we can deduce from the
348  // resulting shape.
349  uint64_t currentStride = 1;
350  SmallVector<OpFoldResult> expandedStrides(groupSize);
351  for (int i = groupSize - 1; i >= 0; --i) {
352  expandedStrides[i] = builder.getIndexAttr(currentStride);
353  uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354  if (ShapedType::isDynamic(dimSize)) {
355  assert(!dynSizeIdx && "There must be at most one dynamic size per group");
356  dynSizeIdx = i;
357  continue;
358  }
359 
360  currentStride *= dimSize;
361  }
362 
363  // Collect the statically known information about the original stride.
364  Value source = expandShape.getSrc();
365  auto sourceType = cast<MemRefType>(source.getType());
366  auto [strides, offset] = getStridesAndOffset(sourceType);
367 
368  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369  ? origStrides[groupId]
370  : builder.getIndexAttr(strides[groupId]);
371 
372  // Apply the original stride to all the strides.
373  int64_t doneStrideIdx = 0;
374  // If we saw a dynamic dimension, we need to fix-up all the strides up to
375  // that dimension with the dynamic size.
376  if (dynSizeIdx) {
377  int64_t productOfAllStaticSizes = currentStride;
378  assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379  "We shouldn't be able to change dynamicity");
380  OpFoldResult origSize = origSizes[groupId];
381 
382  AffineExpr s0 = builder.getAffineSymbolExpr(0);
383  AffineExpr s1 = builder.getAffineSymbolExpr(1);
384  for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385  int64_t baseExpandedStride =
386  cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
387  .getInt();
388  expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
389  builder, expandShape.getLoc(),
390  (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391  {origSize, origStride});
392  }
393  }
394 
395  // Now apply the origStride to the remaining dimensions.
396  AffineExpr s0 = builder.getAffineSymbolExpr(0);
397  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398  int64_t baseExpandedStride =
399  cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
400  .getInt();
401  expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
402  builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
403  }
404 
405  return expandedStrides;
406 }
407 
408 /// Produce an OpFoldResult object with \p builder at \p loc representing
409 /// `prod(valueOrConstant#i, for i in {indices})`,
410 /// where valueOrConstant#i is maybeConstants[i] when \p isDynamic returns
411 /// false for it, and values[i] otherwise.
412 ///
413 /// \pre for all index in indices: index < values.size()
414 /// \pre for all index in indices: index < maybeConstants.size()
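///
/// For instance, with illustrative arguments indices = {0, 2},
/// maybeConstants = [4, ?, ?] and values = [%a, %b, %c], the result is the
/// folded affine product 4 * %c.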
415 static OpFoldResult
416 getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
417  ArrayRef<int64_t> maybeConstants,
418  ArrayRef<OpFoldResult> values,
419  llvm::function_ref<bool(int64_t)> isDynamic) {
420  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
421  SmallVector<OpFoldResult> inputValues;
422  unsigned numberOfSymbols = 0;
423  unsigned groupSize = indices.size();
424  for (unsigned i = 0; i < groupSize; ++i) {
425  productOfValues =
426  productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
427  unsigned srcIdx = indices[i];
428  int64_t maybeConstant = maybeConstants[srcIdx];
429 
430  inputValues.push_back(isDynamic(maybeConstant)
431  ? values[srcIdx]
432  : builder.getIndexAttr(maybeConstant));
433  }
434 
435  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
436  inputValues);
437 }
438 
439 /// Compute the collapsed size of the given \p collapseShape for the
440 /// \p groupId-th reassociation group.
441 /// \p origSizes holds the sizes of the source shape as values.
442 /// This is used to compute the new sizes in cases of dynamic shapes.
443 ///
444 /// Conceptually this helper function computes:
445 /// `prod(origSizes#i, for i in {ressociationGroup[groupId]})`.
446 ///
447 /// \post result.size() == 1, in other words, each group collapses to one
448 /// dimension.
449 ///
450 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
451 /// this is not possible because this function uses the Affine dialect and the
452 /// MemRef dialect cannot depend on the Affine dialect.
453 static SmallVector<OpFoldResult>
454 getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
455  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
456  SmallVector<OpFoldResult> collapsedSize;
457 
458  MemRefType collapseShapeType = collapseShape.getResultType();
459 
460  uint64_t size = collapseShapeType.getDimSize(groupId);
461  if (!ShapedType::isDynamic(size)) {
462  collapsedSize.push_back(builder.getIndexAttr(size));
463  return collapsedSize;
464  }
465 
466  // We are dealing with a dynamic size.
467  // Build the affine expr of the product of the original sizes involved in that
468  // group.
469  Value source = collapseShape.getSrc();
470  auto sourceType = cast<MemRefType>(source.getType());
471 
472  SmallVector<int64_t, 2> reassocGroup =
473  collapseShape.getReassociationIndices()[groupId];
474 
475  collapsedSize.push_back(getProductOfValues(
476  reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
477  origSizes, ShapedType::isDynamic));
478 
479  return collapsedSize;
480 }
481 
482 /// Compute the collapsed stride of the given \p collapseShape for the
483 /// \p groupId-th reassociation group.
484 /// \p origStrides and \p origSizes hold respectively the strides and sizes
485 /// of the source shape as values.
486 /// This is used to compute the strides in cases of dynamic shapes and/or
487 /// dynamic stride for this reassociation group.
488 ///
489 /// Conceptually this helper function returns the stride of the innermost
490 /// dimension of that group in the original shape.
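///
/// For instance, for an illustrative group {1, 2} over a source with strides
/// [32, 8, 1] and no size-1 dimensions in the group, the collapsed stride is
/// min(8, 1) = 1.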
491 ///
492 /// \post result.size() == 1, in other words, each group collapses to one
493 /// dimension.
494 static SmallVector<OpFoldResult>
495 getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
496  ArrayRef<OpFoldResult> origSizes,
497  ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
498  SmallVector<int64_t, 2> reassocGroup =
499  collapseShape.getReassociationIndices()[groupId];
500  assert(!reassocGroup.empty() &&
501  "Reassociation group should have at least one dimension");
502 
503  Value source = collapseShape.getSrc();
504  auto sourceType = cast<MemRefType>(source.getType());
505 
506  auto [strides, offset] = getStridesAndOffset(sourceType);
507 
508  SmallVector<OpFoldResult> groupStrides;
509  ArrayRef<int64_t> srcShape = sourceType.getShape();
510  for (int64_t currentDim : reassocGroup) {
511  // Skip size-of-1 dimensions, since right now their strides may be
512  // meaningless.
513  // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
514  // they are truly contiguous. When they are truly contiguous, we shouldn't
515  // need to skip them.
516  if (srcShape[currentDim] == 1)
517  continue;
518 
519  int64_t currentStride = strides[currentDim];
520  groupStrides.push_back(ShapedType::isDynamic(currentStride)
521  ? origStrides[currentDim]
522  : builder.getIndexAttr(currentStride));
523  }
524  if (groupStrides.empty()) {
525  // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
526  // but we still have to make the type system happy.
527  MemRefType collapsedType = collapseShape.getResultType();
528  auto [collapsedStrides, collapsedOffset] =
529  getStridesAndOffset(collapsedType);
530  int64_t finalStride = collapsedStrides[groupId];
531  if (ShapedType::isDynamic(finalStride)) {
532  // Look for a dynamic stride. At this point we don't know which one is
533  // desired, but they are all equally good/bad.
534  for (int64_t currentDim : reassocGroup) {
535  assert(srcShape[currentDim] == 1 &&
536  "We should be dealing with 1x1x...x1");
537 
538  if (ShapedType::isDynamic(strides[currentDim]))
539  return {origStrides[currentDim]};
540  }
541  llvm_unreachable("We should have found a dynamic stride");
542  }
543  return {builder.getIndexAttr(finalStride)};
544  }
545 
546  // For the general case, we just want the minimum stride
547  // since the collapsed dimensions are contiguous.
548  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
549  builder.getContext());
550  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
551  groupStrides)};
552 }
553 
554 /// From `reshape_like(memref, subSizes, subStrides)` compute
555 ///
556 /// \verbatim
557 /// baseBuffer, baseOffset, baseSizes, baseStrides =
558 /// extract_strided_metadata(memref)
559 /// strides#i = baseStrides#i * subStrides#i
560 /// sizes = subSizes
561 /// \endverbatim
562 ///
563 /// and return {baseBuffer, baseOffset, sizes, strides}
564 template <typename ReassociativeReshapeLikeOp>
565 static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
566  RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
567  function_ref<SmallVector<OpFoldResult>(
568  ReassociativeReshapeLikeOp, OpBuilder &,
569  ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
570  getReshapedSizes,
571  function_ref<SmallVector<OpFoldResult>(
572  ReassociativeReshapeLikeOp, OpBuilder &,
573  ArrayRef<OpFoldResult> /*origSizes*/,
574  ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
575  getReshapedStrides) {
576  // Build a plain extract_strided_metadata(memref) from
577  // extract_strided_metadata(reassociative_reshape_like(memref)).
578  Location origLoc = reshape.getLoc();
579  Value source = reshape.getSrc();
580  auto sourceType = cast<MemRefType>(source.getType());
581  unsigned sourceRank = sourceType.getRank();
582 
583  auto newExtractStridedMetadata =
584  rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
585 
586  // Collect statically known information.
587  auto [strides, offset] = getStridesAndOffset(sourceType);
588  MemRefType reshapeType = reshape.getResultType();
589  unsigned reshapeRank = reshapeType.getRank();
590 
591  OpFoldResult offsetOfr =
592  ShapedType::isDynamic(offset)
593  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
594  : rewriter.getIndexAttr(offset);
595 
596  // Get the special case of 0-D out of the way.
597  if (sourceRank == 0) {
598  SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
599  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
600  /*sizes=*/ones, /*strides=*/ones};
601  }
602 
603  SmallVector<OpFoldResult> finalSizes;
604  finalSizes.reserve(reshapeRank);
605  SmallVector<OpFoldResult> finalStrides;
606  finalStrides.reserve(reshapeRank);
607 
608  // Compute the reshaped strides and sizes from the base strides and sizes.
609  SmallVector<OpFoldResult> origSizes =
610  getAsOpFoldResult(newExtractStridedMetadata.getSizes());
611  SmallVector<OpFoldResult> origStrides =
612  getAsOpFoldResult(newExtractStridedMetadata.getStrides());
613  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
614  for (; idx != endIdx; ++idx) {
615  SmallVector<OpFoldResult> reshapedSizes =
616  getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
617  SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
618  reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
619 
620  unsigned groupSize = reshapedSizes.size();
621  for (unsigned i = 0; i < groupSize; ++i) {
622  finalSizes.push_back(reshapedSizes[i]);
623  finalStrides.push_back(reshapedStrides[i]);
624  }
625  }
626  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
627  (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
628  "We should have visited all the input dimensions");
629  assert(finalSizes.size() == reshapeRank &&
630  "We should have populated all the values");
631 
632  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
633  finalSizes, finalStrides};
634 }
635 
636 /// Replace `baseBuffer, offset, sizes, strides =
637 /// extract_strided_metadata(reshapeLike(memref))`
638 /// With
639 ///
640 /// \verbatim
641 /// baseBuffer, offset, baseSizes, baseStrides =
642 /// extract_strided_metadata(memref)
643 /// sizes = getReshapedSizes(reshapeLike)
644 /// strides = getReshapedStrides(reshapeLike)
645 /// \endverbatim
646 ///
647 ///
648 /// Notice that `baseBuffer` and `offset` are unchanged.
649 ///
650 /// In other words, get rid of the reshape-like op in that expression and
651 /// materialize its effects on the sizes and the strides using affine.apply.
652 template <typename ReassociativeReshapeLikeOp,
653  SmallVector<OpFoldResult> (*getReshapedSizes)(
654  ReassociativeReshapeLikeOp, OpBuilder &,
655  ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
656  SmallVector<OpFoldResult> (*getReshapedStrides)(
657  ReassociativeReshapeLikeOp, OpBuilder &,
658  ArrayRef<OpFoldResult> /*origSizes*/,
659  ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
660 struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
661 public:
662  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;
663 
664  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
665  PatternRewriter &rewriter) const override {
666  FailureOr<StridedMetadata> stridedMetadata =
667  resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
668  rewriter, reshape, getReshapedSizes, getReshapedStrides);
669  if (failed(stridedMetadata)) {
670  return rewriter.notifyMatchFailure(reshape,
671  "failed to resolve reshape metadata");
672  }
673 
674  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
675  reshape, reshape.getType(), stridedMetadata->basePtr,
676  stridedMetadata->offset, stridedMetadata->sizes,
677  stridedMetadata->strides);
678  return success();
679  }
680 };
681 
682 /// Pattern to replace `extract_strided_metadata(collapse_shape)`
683 /// With
684 ///
685 /// \verbatim
686 /// baseBuffer, baseOffset, baseSizes, baseStrides =
687 /// extract_strided_metadata(memref)
688 /// offset = baseOffset
689 /// sizes#i = getCollapsedSize(collapse_shape, i)
690 /// strides#i = getCollapsedStride(collapse_shape, i)
691 /// \endverbatim
692 ///
693 /// with `baseBuffer`, `offset`, `sizes` and `strides` being
694 /// the replacements for the original `extract_strided_metadata`.
695 struct ExtractStridedMetadataOpCollapseShapeFolder
696  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
697  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
698 
699  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
700  PatternRewriter &rewriter) const override {
701  auto collapseShapeOp =
702  op.getSource().getDefiningOp<memref::CollapseShapeOp>();
703  if (!collapseShapeOp)
704  return failure();
705 
706  FailureOr<StridedMetadata> stridedMetadata =
707  resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
708  rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
709  if (failed(stridedMetadata)) {
710  return rewriter.notifyMatchFailure(
711  op,
712  "failed to resolve metadata in terms of source collapse_shape op");
713  }
714 
715  Location loc = collapseShapeOp.getLoc();
716  SmallVector<Value> results;
717  results.push_back(stridedMetadata->basePtr);
718  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
719  stridedMetadata->offset));
720  results.append(
721  getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
722  results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
723  stridedMetadata->strides));
724  rewriter.replaceOp(op, results);
725  return success();
726  }
727 };
728 
729 /// Pattern to replace `extract_strided_metadata(expand_shape)`
730 /// with the results of computing the sizes and strides on the expanded shape
731 /// and dividing up dimensions into static and dynamic parts as needed.
732 struct ExtractStridedMetadataOpExpandShapeFolder
733  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
734  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
735 
736  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
737  PatternRewriter &rewriter) const override {
738  auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
739  if (!expandShapeOp)
740  return failure();
741 
742  FailureOr<StridedMetadata> stridedMetadata =
743  resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
744  rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
745  if (failed(stridedMetadata)) {
746  return rewriter.notifyMatchFailure(
747  op, "failed to resolve metadata in terms of source expand_shape op");
748  }
749 
750  Location loc = expandShapeOp.getLoc();
751  SmallVector<Value> results;
752  results.push_back(stridedMetadata->basePtr);
753  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
754  stridedMetadata->offset));
755  results.append(
756  getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
757  results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
758  stridedMetadata->strides));
759  rewriter.replaceOp(op, results);
760  return success();
761  }
762 };
763 
764 /// Replace `base, offset, sizes, strides =
765 /// extract_strided_metadata(allocLikeOp)`
766 ///
767 /// With
768 ///
769 /// ```
770 /// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
771 /// offset = 0
772 /// sizes = allocSizes
773 /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
774 /// ```
775 ///
776 /// The transformation only applies if the allocLikeOp has been normalized.
777 /// In other words, the affine_map must be an identity.
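///
/// For instance, for an illustrative `%a = memref.alloc(%d) : memref<2x?x8xf32>`:
///   sizes   = [2, %d, 8]
///   strides = [8 * %d, 8, 1]
///   offset  = 0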
778 template <typename AllocLikeOp>
779 struct ExtractStridedMetadataOpAllocFolder
780  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
781 public:
782  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
783 
784  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
785  PatternRewriter &rewriter) const override {
786  auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
787  if (!allocLikeOp)
788  return failure();
789 
790  auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
791  if (!memRefType.getLayout().isIdentity())
792  return rewriter.notifyMatchFailure(
793  allocLikeOp, "alloc-like operations should have been normalized");
794 
795  Location loc = op.getLoc();
796  int rank = memRefType.getRank();
797 
798  // Collect the sizes.
799  ValueRange dynamic = allocLikeOp.getDynamicSizes();
800  SmallVector<OpFoldResult> sizes;
801  sizes.reserve(rank);
802  unsigned dynamicPos = 0;
803  for (int64_t size : memRefType.getShape()) {
804  if (ShapedType::isDynamic(size))
805  sizes.push_back(dynamic[dynamicPos++]);
806  else
807  sizes.push_back(rewriter.getIndexAttr(size));
808  }
809 
810  // Strides (just creates identity strides).
811  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
812  AffineExpr expr = rewriter.getAffineConstantExpr(1);
813  unsigned symbolNumber = 0;
814  for (int i = rank - 2; i >= 0; --i) {
815  expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
816  assert(i + 1 + symbolNumber == sizes.size() &&
817  "The ArrayRef should encompass the last #symbolNumber sizes");
818  ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
819  strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
820  sizesInvolvedInStride);
821  }
822 
823  // Put all the values together to replace the results.
824  SmallVector<Value> results;
825  results.reserve(rank * 2 + 2);
826 
827  auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
828  int64_t offset = 0;
829  if (op.getBaseBuffer().use_empty()) {
830  results.push_back(nullptr);
831  } else {
832  if (allocLikeOp.getType() == baseBufferType)
833  results.push_back(allocLikeOp);
834  else
835  results.push_back(rewriter.create<memref::ReinterpretCastOp>(
836  loc, baseBufferType, allocLikeOp, offset,
837  /*sizes=*/ArrayRef<int64_t>(),
838  /*strides=*/ArrayRef<int64_t>()));
839  }
840 
841  // Offset.
842  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
843 
844  for (OpFoldResult size : sizes)
845  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));
846 
847  for (OpFoldResult stride : strides)
848  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));
849 
850  rewriter.replaceOp(op, results);
851  return success();
852  }
853 };
854 
855 /// Replace `base, offset, sizes, strides =
856 /// extract_strided_metadata(get_global)`
857 ///
858 /// With
859 ///
860 /// ```
861 /// base = reinterpret_cast get_global to a flat memref<eltTy>
862 /// offset = 0
863 /// sizes = globalSizes (the static shape of the global)
864 /// strides#i = prod(globalSizes#j, for j in {i+1..rank-1})
865 /// ```
866 ///
867 /// It is expected that the memref.get_global op has static shapes
868 /// and identity affine_map for the layout.
869 struct ExtractStridedMetadataOpGetGlobalFolder
870  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
871 public:
872  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
873 
874  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
875  PatternRewriter &rewriter) const override {
876  auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
877  if (!getGlobalOp)
878  return failure();
879 
880  auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
881  if (!memRefType.getLayout().isIdentity()) {
882  return rewriter.notifyMatchFailure(
883  getGlobalOp,
884  "get-global operation result should have been normalized");
885  }
886 
887  Location loc = op.getLoc();
888  int rank = memRefType.getRank();
889 
890  // Collect the sizes.
891  ArrayRef<int64_t> sizes = memRefType.getShape();
892  assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
893  "unexpected dynamic shape for result of `memref.get_global` op");
894 
895  // Strides (just creates identity strides).
896  SmallVector<int64_t> strides = computeSuffixProduct(sizes);
897 
898  // Put all the values together to replace the results.
899  SmallVector<Value> results;
900  results.reserve(rank * 2 + 2);
901 
902  auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
903  int64_t offset = 0;
904  if (getGlobalOp.getType() == baseBufferType)
905  results.push_back(getGlobalOp);
906  else
907  results.push_back(rewriter.create<memref::ReinterpretCastOp>(
908  loc, baseBufferType, getGlobalOp, offset,
909  /*sizes=*/ArrayRef<int64_t>(),
910  /*strides=*/ArrayRef<int64_t>()));
911 
912  // Offset.
913  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
914 
915  for (auto size : sizes)
916  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));
917 
918  for (auto stride : strides)
919  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));
920 
921  rewriter.replaceOp(op, results);
922  return success();
923  }
924 };
925 
926 /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
927 /// source of the ViewLikeOp.
928 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
929  : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
930  using OpRewritePattern::OpRewritePattern;
931 
932  LogicalResult
933  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
934  PatternRewriter &rewriter) const override {
935  auto viewLikeOp =
936  extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
937  if (!viewLikeOp)
938  return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
939  rewriter.modifyOpInPlace(extractOp, [&]() {
940  extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
941  });
942  return success();
943  }
944 };
945 
946 /// Replace `base, offset, sizes, strides =
947 /// extract_strided_metadata(
948 /// reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
949 /// With
950 /// ```
951 /// base, ... = extract_strided_metadata(src)
952 /// offset = srcOffset
953 /// sizes = srcSizes
954 /// strides = srcStrides
955 /// ```
956 ///
957 /// In other words, consume the `reinterpret_cast` and apply its effects
958 /// on the offset, sizes, and strides.
959 class ExtractStridedMetadataOpReinterpretCastFolder
960  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
961  using OpRewritePattern::OpRewritePattern;
962 
963  LogicalResult
964  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
965  PatternRewriter &rewriter) const override {
966  auto reinterpretCastOp = extractStridedMetadataOp.getSource()
967  .getDefiningOp<memref::ReinterpretCastOp>();
968  if (!reinterpretCastOp)
969  return failure();
970 
971  Location loc = extractStridedMetadataOp.getLoc();
972  // Check if the source is suitable for extract_strided_metadata.
973  SmallVector<Type> inferredReturnTypes;
974  if (failed(extractStridedMetadataOp.inferReturnTypes(
975  rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
976  /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
977  inferredReturnTypes)))
978  return rewriter.notifyMatchFailure(
979  reinterpretCastOp, "reinterpret_cast source's type is incompatible");
980 
981  auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
982  unsigned rank = memrefType.getRank();
983  SmallVector<OpFoldResult> results;
984  results.resize_for_overwrite(rank * 2 + 2);
985 
986  auto newExtractStridedMetadata =
987  rewriter.create<memref::ExtractStridedMetadataOp>(
988  loc, reinterpretCastOp.getSource());
989 
990  // Register the base_buffer.
991  results[0] = newExtractStridedMetadata.getBaseBuffer();
992 
993  // Register the new offset.
994  results[1] = getValueOrCreateConstantIndexOp(
995  rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
996 
997  const unsigned sizeStartIdx = 2;
998  const unsigned strideStartIdx = sizeStartIdx + rank;
999 
1000  SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
1001  SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
1002  for (unsigned i = 0; i < rank; ++i) {
1003  results[sizeStartIdx + i] = sizes[i];
1004  results[strideStartIdx + i] = strides[i];
1005  }
1006  rewriter.replaceOp(extractStridedMetadataOp,
1007  getValueOrCreateConstantIndexOp(rewriter, loc, results));
1008  return success();
1009  }
1010 };
1011 
1012 /// Replace `base, offset, sizes, strides =
1013 /// extract_strided_metadata(
1014 /// cast(src) to dstTy)`
1015 /// With
1016 /// ```
1017 /// base, ... = extract_strided_metadata(src)
1018 /// offset = !dstTy.srcOffset.isDynamic()
1019 /// ? dstTy.srcOffset
1020 /// : extract_strided_metadata(src).offset
1021 /// sizes = for each srcSize in dstTy.srcSizes:
1022 /// !srcSize.isDynamic()
1023 /// ? srcSize
1024 /// : extract_strided_metadata(src).sizes[i]
1025 /// strides = for each srcStride in dstTy.srcStrides:
1026 /// !srcStride.isDynamic()
1027 /// ? srcStride
1028 /// : extract_strided_metadata(src).strides[i]
1029 /// ```
1030 ///
1031 /// In other words, consume the `cast` and apply its effects
1032 /// on the offset, sizes, and strides or compute them directly from `src`.
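///
/// For instance, for an illustrative cast from `memref<?x?xf32>` to
/// `memref<4x?xf32>`, the first size and the innermost stride become the
/// constants 4 and 1, while the remaining size and stride are taken from
/// `extract_strided_metadata(src)`.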
1033 class ExtractStridedMetadataOpCastFolder
1034  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
1035  using OpRewritePattern::OpRewritePattern;
1036 
1037  LogicalResult
1038  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1039  PatternRewriter &rewriter) const override {
1040  Value source = extractStridedMetadataOp.getSource();
1041  auto castOp = source.getDefiningOp<memref::CastOp>();
1042  if (!castOp)
1043  return failure();
1044 
1045  Location loc = extractStridedMetadataOp.getLoc();
1046  // Check if the source is suitable for extract_strided_metadata.
1047  SmallVector<Type> inferredReturnTypes;
1048  if (failed(extractStridedMetadataOp.inferReturnTypes(
1049  rewriter.getContext(), loc, {castOp.getSource()},
1050  /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
1051  inferredReturnTypes)))
1052  return rewriter.notifyMatchFailure(castOp,
1053  "cast source's type is incompatible");
1054 
1055  auto memrefType = cast<MemRefType>(source.getType());
1056  unsigned rank = memrefType.getRank();
1057  SmallVector<OpFoldResult> results;
1058  results.resize_for_overwrite(rank * 2 + 2);
1059 
1060  auto newExtractStridedMetadata =
1061  rewriter.create<memref::ExtractStridedMetadataOp>(loc,
1062  castOp.getSource());
1063 
1064  // Register the base_buffer.
1065  results[0] = newExtractStridedMetadata.getBaseBuffer();
1066 
1067  auto getConstantOrValue = [&rewriter](int64_t constant,
1068  OpFoldResult ofr) -> OpFoldResult {
1069  return !ShapedType::isDynamic(constant)
1070  ? OpFoldResult(rewriter.getIndexAttr(constant))
1071  : ofr;
1072  };
1073 
1074  auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
1075  assert(sourceStrides.size() == rank && "unexpected number of strides");
1076 
1077  // Register the new offset.
1078  results[1] =
1079  getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
1080 
1081  const unsigned sizeStartIdx = 2;
1082  const unsigned strideStartIdx = sizeStartIdx + rank;
1083  ArrayRef<int64_t> sourceSizes = memrefType.getShape();
1084 
1085  SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
1086  SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
1087  for (unsigned i = 0; i < rank; ++i) {
1088  results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
1089  results[strideStartIdx + i] =
1090  getConstantOrValue(sourceStrides[i], strides[i]);
1091  }
1092  rewriter.replaceOp(extractStridedMetadataOp,
1093  getValueOrCreateConstantIndexOp(rewriter, loc, results));
1094  return success();
1095  }
1096 };
1097 
1098 /// Replace `base, offset, sizes, strides = extract_strided_metadata(
1099 /// memory_space_cast(src) to dstTy)`
1100 /// with
1101 /// ```
1102 /// oldBase, offset, sizes, strides = extract_strided_metadata(src)
1103 /// destBaseTy = type(oldBase) with memory space from destTy
1104 /// base = memory_space_cast(oldBase) to destBaseTy
1105 /// ```
1106 ///
1107 /// In other words, propagate metadata extraction across memory space casts.
1108 class ExtractStridedMetadataOpMemorySpaceCastFolder
1109  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
1110  using OpRewritePattern::OpRewritePattern;
1111 
1112  LogicalResult
1113  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1114  PatternRewriter &rewriter) const override {
1115  Location loc = extractStridedMetadataOp.getLoc();
1116  Value source = extractStridedMetadataOp.getSource();
1117  auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
1118  if (!memSpaceCastOp)
1119  return failure();
1120  auto newExtractStridedMetadata =
1121  rewriter.create<memref::ExtractStridedMetadataOp>(
1122  loc, memSpaceCastOp.getSource());
1123  SmallVector<Value> results(newExtractStridedMetadata.getResults());
1124  // As with most other strided metadata rewrite patterns, don't introduce
1125  // a use of the base pointer where none existed. This needs to happen here,
1126  // as opposed to in later dead-code elimination, because these patterns are
1127  // sometimes used during dialect conversion (see EmulateNarrowType, for
1128  // example), so adding spurious usages would cause a pre-legalization value
1129  // to be live that would be dead had this pattern not run.
1130  if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1131  auto baseBuffer = results[0];
1132  auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1133  MemRefType::Builder newTypeBuilder(baseBufferType);
1134  newTypeBuilder.setMemorySpace(
1135  memSpaceCastOp.getResult().getType().getMemorySpace());
1136  results[0] = rewriter.create<memref::MemorySpaceCastOp>(
1137  loc, Type{newTypeBuilder}, baseBuffer);
1138  } else {
1139  results[0] = nullptr;
1140  }
1141  rewriter.replaceOp(extractStridedMetadataOp, results);
1142  return success();
1143  }
1144 };
1145 
1146 /// Replace `base, offset =
1147 /// extract_strided_metadata(extract_strided_metadata(src)#0)`
1148 /// With
1149 /// ```
1150 /// base, ... = extract_strided_metadata(src)
1151 /// offset = 0
1152 /// ```
1153 class ExtractStridedMetadataOpExtractStridedMetadataFolder
1154  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
1155  using OpRewritePattern::OpRewritePattern;
1156 
1157  LogicalResult
1158  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1159  PatternRewriter &rewriter) const override {
1160  auto sourceExtractStridedMetadataOp =
1161  extractStridedMetadataOp.getSource()
1162  .getDefiningOp<memref::ExtractStridedMetadataOp>();
1163  if (!sourceExtractStridedMetadataOp)
1164  return failure();
1165  Location loc = extractStridedMetadataOp.getLoc();
1166  rewriter.replaceOp(extractStridedMetadataOp,
1167  {sourceExtractStridedMetadataOp.getBaseBuffer(),
1168  getValueOrCreateConstantIndexOp(
1169  rewriter, loc, rewriter.getIndexAttr(0))});
1170  return success();
1171  }
1172 };
1173 } // namespace
1174 
1175 void memref::populateExpandStridedMetadataPatterns(
1176  RewritePatternSet &patterns) {
1177  patterns.add<SubviewFolder,
1178  ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
1179  getExpandedStrides>,
1180  ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
1181  getCollapsedStride>,
1182  ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1183  ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1184  ExtractStridedMetadataOpCollapseShapeFolder,
1185  ExtractStridedMetadataOpExpandShapeFolder,
1186  ExtractStridedMetadataOpGetGlobalFolder,
1187  RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1188  ExtractStridedMetadataOpReinterpretCastFolder,
1189  ExtractStridedMetadataOpSubviewFolder,
1190  ExtractStridedMetadataOpCastFolder,
1191  ExtractStridedMetadataOpMemorySpaceCastFolder,
1192  ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1193  patterns.getContext());
1194 }
1195 
1196 void memref::populateResolveExtractStridedMetadataPatterns(
1197  RewritePatternSet &patterns) {
1198  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1199  ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1200  ExtractStridedMetadataOpCollapseShapeFolder,
1201  ExtractStridedMetadataOpExpandShapeFolder,
1202  ExtractStridedMetadataOpGetGlobalFolder,
1203  ExtractStridedMetadataOpSubviewFolder,
1204  RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1205  ExtractStridedMetadataOpReinterpretCastFolder,
1206  ExtractStridedMetadataOpCastFolder,
1207  ExtractStridedMetadataOpMemorySpaceCastFolder,
1208  ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1209  patterns.getContext());
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // Pass registration
1214 //===----------------------------------------------------------------------===//
1215 
1216 namespace {
1217 
1218 struct ExpandStridedMetadataPass final
1219  : public memref::impl::ExpandStridedMetadataBase<
1220  ExpandStridedMetadataPass> {
1221  void runOnOperation() override;
1222 };
1223 
1224 } // namespace
1225 
1226 void ExpandStridedMetadataPass::runOnOperation() {
1227  RewritePatternSet patterns(&getContext());
1228  memref::populateExpandStridedMetadataPatterns(patterns);
1229  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1230 }
1231 
1232 std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
1233  return std::make_unique<ExpandStridedMetadataPass>();
1234 }