1 //===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// The pass expands memref operations that modify the metadata of a memref
10 /// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
11 /// In particular, this pass transforms operations into an explicit sequence of
12 /// operations that models the effect of the original operation on the different
13 /// metadata. This pass uses affine constructs to materialize these effects.
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
20 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21 #include "mlir/Dialect/Utils/IndexingUtils.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/Interfaces/ViewLikeInterface.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include <optional>
28 
29 namespace mlir {
30 namespace memref {
31 #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
32 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
33 } // namespace memref
34 } // namespace mlir
35 
36 using namespace mlir;
37 using namespace mlir::affine;
38 
39 namespace {
40 
41 struct StridedMetadata {
42  Value basePtr;
43  OpFoldResult offset;
44  SmallVector<OpFoldResult> sizes;
45  SmallVector<OpFoldResult> strides;
46 };
47 
48 /// From `subview(memref, subOffset, subSizes, subStrides)` compute
49 ///
50 /// \verbatim
51 /// baseBuffer, baseOffset, baseSizes, baseStrides =
52 /// extract_strided_metadata(memref)
53 /// strides#i = baseStrides#i * subStrides#i
54 /// offset = baseOffset + sum(subOffset#i * baseStrides#i)
55 /// sizes = subSizes
56 /// \endverbatim
57 ///
58 /// and return {baseBuffer, offset, sizes, strides}
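///
/// For example (hypothetical shapes, used only to illustrate the arithmetic):
/// with a source `memref<8x16xf32>` (baseStrides = [16, 1], baseOffset = 0)
/// and subOffsets = [2, 3], subSizes = [4, 4], subStrides = [1, 2], this
/// computes strides = [1 * 16, 2 * 1] = [16, 2],
/// offset = 0 + 2 * 16 + 3 * 1 = 35, and sizes = [4, 4].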
59 static FailureOr<StridedMetadata>
60 resolveSubviewStridedMetadata(RewriterBase &rewriter,
61  memref::SubViewOp subview) {
62  // Build a plain extract_strided_metadata(memref) from subview(memref).
63  Location origLoc = subview.getLoc();
64  Value source = subview.getSource();
65  auto sourceType = cast<MemRefType>(source.getType());
66  unsigned sourceRank = sourceType.getRank();
67 
68  auto newExtractStridedMetadata =
69  rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
70 
71  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
72 
73  // Compute the new strides and offset from the base strides and offset:
74  // newStride#i = baseStride#i * subStride#i
75  // offset = baseOffset + sum(subOffsets#i * baseStrides#i)
76  SmallVector<OpFoldResult> strides;
77  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
78  auto origStrides = newExtractStridedMetadata.getStrides();
79 
80  // Hold the affine symbols and values for the computation of the offset.
81  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
82  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
83 
84  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
85  AffineExpr expr = symbols.front();
86  values[0] = ShapedType::isDynamic(sourceOffset)
87  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
88  : rewriter.getIndexAttr(sourceOffset);
89  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
90 
91  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
92  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
93  for (unsigned i = 0; i < sourceRank; ++i) {
94  // Compute the stride.
95  OpFoldResult origStride =
96  ShapedType::isDynamic(sourceStrides[i])
97  ? origStrides[i]
98  : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
99  strides.push_back(makeComposedFoldedAffineApply(
100  rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
101 
102  // Build up the computation of the offset.
103  unsigned baseIdxForDim = 1 + 2 * i;
104  unsigned subOffsetForDim = baseIdxForDim;
105  unsigned origStrideForDim = baseIdxForDim + 1;
106  expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
107  values[subOffsetForDim] = subOffsets[i];
108  values[origStrideForDim] = origStride;
109  }
110 
111  // Compute the offset.
112  OpFoldResult finalOffset =
113  makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
114 
115  // The final result is <baseBuffer, offset, sizes, strides>.
116  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to
117  // hold them all.
118  auto subType = cast<MemRefType>(subview.getType());
119  unsigned subRank = subType.getRank();
120 
121  // The sizes of the final type are defined directly by the input sizes of
122  // the subview.
123  // Moreover, since subviews can drop some dimensions, some strides and sizes may
124  // not end up in the final <base, offset, sizes, strides> value that we are
125  // replacing.
126  // Do the filtering here.
127  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
128  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
129 
130  SmallVector<OpFoldResult> finalSizes;
131  finalSizes.reserve(subRank);
132 
133  SmallVector<OpFoldResult> finalStrides;
134  finalStrides.reserve(subRank);
135 
136  for (unsigned i = 0; i < sourceRank; ++i) {
137  if (droppedDims.test(i))
138  continue;
139 
140  finalSizes.push_back(subSizes[i]);
141  finalStrides.push_back(strides[i]);
142  }
143  assert(finalSizes.size() == subRank &&
144  "Should have populated all the values at this point");
145  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
146  finalSizes, finalStrides};
147 }
148 
149 /// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
150 /// With
151 ///
152 /// \verbatim
153 /// baseBuffer, baseOffset, baseSizes, baseStrides =
154 /// extract_strided_metadata(memref)
155 /// strides#i = baseStrides#i * subStrides#i
156 /// offset = baseOffset + sum(subOffset#i * baseStrides#i)
157 /// sizes = subSizes
158 /// dst = reinterpret_cast baseBuffer, offset, sizes, strides
159 /// \endverbatim
160 ///
161 /// In other words, get rid of the subview in that expression and materialize
162 /// its effects on the offset, the sizes, and the strides using affine.apply.
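///
/// Schematically, in pseudo-IR (hypothetical example, names and types
/// abbreviated rather than taken from an actual test):
///
/// \verbatim
/// %dst = memref.subview %src [subOffsets] [subSizes] [subStrides]
/// \endverbatim
///
/// becomes
///
/// \verbatim
/// %base, %offset, %sizes:n, %strides:n =
///     memref.extract_strided_metadata %src
/// // affine.apply ops folding the new offset and strides ...
/// %dst = memref.reinterpret_cast %base to
///     offset: [newOffset], sizes: [subSizes], strides: [newStrides]
/// \endverbatim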
163 struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
164 public:
165  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
166 
167  LogicalResult matchAndRewrite(memref::SubViewOp subview,
168  PatternRewriter &rewriter) const override {
169  FailureOr<StridedMetadata> stridedMetadata =
170  resolveSubviewStridedMetadata(rewriter, subview);
171  if (failed(stridedMetadata)) {
172  return rewriter.notifyMatchFailure(subview,
173  "failed to resolve subview metadata");
174  }
175 
176  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
177  subview, subview.getType(), stridedMetadata->basePtr,
178  stridedMetadata->offset, stridedMetadata->sizes,
179  stridedMetadata->strides);
180  return success();
181  }
182 };
183 
184 /// Pattern to replace `extract_strided_metadata(subview)`
185 /// With
186 ///
187 /// \verbatim
188 /// baseBuffer, baseOffset, baseSizes, baseStrides =
189 /// extract_strided_metadata(memref)
190 /// strides#i = baseStrides#i * subStrides#i
191 /// offset = baseOffset + sum(subOffset#i * baseStrides#i)
192 /// sizes = subSizes
193 /// \endverbatim
194 ///
195 /// with `baseBuffer`, `offset`, `sizes` and `strides` being
196 /// the replacements for the original `extract_strided_metadata`.
197 struct ExtractStridedMetadataOpSubviewFolder
198  : OpRewritePattern<memref::ExtractStridedMetadataOp> {
199  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
200 
201  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
202  PatternRewriter &rewriter) const override {
203  auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
204  if (!subviewOp)
205  return failure();
206 
207  FailureOr<StridedMetadata> stridedMetadata =
208  resolveSubviewStridedMetadata(rewriter, subviewOp);
209  if (failed(stridedMetadata)) {
210  return rewriter.notifyMatchFailure(
211  op, "failed to resolve metadata in terms of source subview op");
212  }
213  Location loc = subviewOp.getLoc();
214  SmallVector<Value> results;
215  results.reserve(subviewOp.getType().getRank() * 2 + 2);
216  results.push_back(stridedMetadata->basePtr);
217  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
218  stridedMetadata->offset));
219  results.append(
220  getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
221  results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
222  stridedMetadata->strides));
223  rewriter.replaceOp(op, results);
224 
225  return success();
226  }
227 };
228 
229 /// Compute the expanded sizes of the given \p expandShape for the
230 /// \p groupId-th reassociation group.
231 /// \p origSizes hold the sizes of the source shape as values.
232 /// This is used to compute the new sizes in cases of dynamic shapes.
233 ///
234 /// sizes#i =
235 /// baseSizes#groupId / product(expandShapeSizes#j,
236 /// for j in group excluding reassIdx#i)
237 /// Where reassIdx#i is the reassociation index at index i in \p groupId.
238 ///
239 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
240 ///
241 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
242 /// this is not possible because this function uses the Affine dialect and the
243 /// MemRef dialect cannot depend on the Affine dialect.
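///
/// For instance (hypothetical shapes): if a source dimension is expanded into
/// result dimensions `?x4`, the static dimension keeps size 4 and the dynamic
/// one is computed as `baseSizes#groupId floorDiv 4` with an affine.apply
/// (e.g., 8 when the source size is 32).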
244 static SmallVector<OpFoldResult>
245 getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
246  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
247  SmallVector<int64_t, 2> reassocGroup =
248  expandShape.getReassociationIndices()[groupId];
249  assert(!reassocGroup.empty() &&
250  "Reassociation group should have at least one dimension");
251 
252  unsigned groupSize = reassocGroup.size();
253  SmallVector<OpFoldResult> expandedSizes(groupSize);
254 
255  uint64_t productOfAllStaticSizes = 1;
256  std::optional<unsigned> dynSizeIdx;
257  MemRefType expandShapeType = expandShape.getResultType();
258 
259  // Fill up all the statically known sizes.
260  for (unsigned i = 0; i < groupSize; ++i) {
261  uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
262  if (ShapedType::isDynamic(dimSize)) {
263  assert(!dynSizeIdx && "There must be at most one dynamic size per group");
264  dynSizeIdx = i;
265  continue;
266  }
267  productOfAllStaticSizes *= dimSize;
268  expandedSizes[i] = builder.getIndexAttr(dimSize);
269  }
270 
271  // Compute the dynamic size using the original size and all the other known
272  // static sizes:
273  // expandSize = origSize / productOfAllStaticSizes.
274  if (dynSizeIdx) {
275  AffineExpr s0 = builder.getAffineSymbolExpr(0);
276  expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
277  builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
278  origSizes[groupId]);
279  }
280 
281  return expandedSizes;
282 }
283 
284 /// Compute the expanded strides of the given \p expandShape for the
285 /// \p groupId-th reassociation group.
286 /// \p origStrides and \p origSizes hold respectively the strides and sizes
287 /// of the source shape as values.
288 /// This is used to compute the strides in cases of dynamic shapes and/or
289 /// dynamic stride for this reassociation group.
290 ///
291 /// strides#i =
292 /// origStrides#reassDim * product(expandShapeSizes#j, for j in
293 /// reassIdx#i+1..reassIdx#i+group.size-1)
294 ///
295 /// Where reassIdx#i is the reassociation index at index i in \p groupId
296 /// and expandShapeSizes#j is either:
297 /// - The constant size at dimension j, derived directly from the result type of
298 /// the expand_shape op, or
299 /// - An affine expression: baseSizes#reassDim / product of all constant sizes
300 /// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
301 /// element.)
302 ///
303 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
304 ///
305 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
306 /// this is not possible because this function uses the Affine dialect and the
307 /// MemRef dialect cannot depend on the Affine dialect.
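///
/// For instance (hypothetical shapes): expanding one source dimension of size
/// 12 and stride 2 into result dimensions `3x4` first yields the local suffix
/// products [4, 1], which are then scaled by the original stride to give
/// [4 * 2, 1 * 2] = [8, 2].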
308 SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
309  OpBuilder &builder,
310  ArrayRef<OpFoldResult> origSizes,
311  ArrayRef<OpFoldResult> origStrides,
312  unsigned groupId) {
313  SmallVector<int64_t, 2> reassocGroup =
314  expandShape.getReassociationIndices()[groupId];
315  assert(!reassocGroup.empty() &&
316  "Reassociation group should have at least one dimension");
317 
318  unsigned groupSize = reassocGroup.size();
319  MemRefType expandShapeType = expandShape.getResultType();
320 
321  std::optional<int64_t> dynSizeIdx;
322 
323  // Fill up the expanded strides, with the information we can deduce from the
324  // resulting shape.
325  uint64_t currentStride = 1;
326  SmallVector<OpFoldResult> expandedStrides(groupSize);
327  for (int i = groupSize - 1; i >= 0; --i) {
328  expandedStrides[i] = builder.getIndexAttr(currentStride);
329  uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
330  if (ShapedType::isDynamic(dimSize)) {
331  assert(!dynSizeIdx && "There must be at most one dynamic size per group");
332  dynSizeIdx = i;
333  continue;
334  }
335 
336  currentStride *= dimSize;
337  }
338 
339  // Collect the statically known information about the original stride.
340  Value source = expandShape.getSrc();
341  auto sourceType = cast<MemRefType>(source.getType());
342  auto [strides, offset] = getStridesAndOffset(sourceType);
343 
344  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
345  ? origStrides[groupId]
346  : builder.getIndexAttr(strides[groupId]);
347 
348  // Apply the original stride to all the strides.
349  int64_t doneStrideIdx = 0;
350  // If we saw a dynamic dimension, we need to fix up all the strides up to
351  // that dimension with the dynamic size.
352  if (dynSizeIdx) {
353  int64_t productOfAllStaticSizes = currentStride;
354  assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
355  "We shouldn't be able to change dynamicity");
356  OpFoldResult origSize = origSizes[groupId];
357 
358  AffineExpr s0 = builder.getAffineSymbolExpr(0);
359  AffineExpr s1 = builder.getAffineSymbolExpr(1);
360  for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
361  int64_t baseExpandedStride =
362  cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
363  .getInt();
364  expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
365  builder, expandShape.getLoc(),
366  (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
367  {origSize, origStride});
368  }
369  }
370 
371  // Now apply the origStride to the remaining dimensions.
372  AffineExpr s0 = builder.getAffineSymbolExpr(0);
373  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
374  int64_t baseExpandedStride =
375  cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
376  .getInt();
377  expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
378  builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
379  }
380 
381  return expandedStrides;
382 }
383 
384 /// Produce an OpFoldResult object with \p builder at \p loc representing
385 /// `prod(valueOrConstant#i, for i in {indices})`,
386 /// where valueOrConstant#i is maybeConstants[i] when \p isDynamic returns false,
387 /// values[i] otherwise.
388 ///
389 /// \pre for all index in indices: index < values.size()
390 /// \pre for all index in indices: index < maybeConstants.size()
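///
/// For instance (hypothetical inputs): with indices = {0, 2},
/// maybeConstants = [4, ?, ?] and values = [%a, %b, %c], the result is the
/// affine.apply of s0 * s1 over (4, %c), i.e. 4 * %c.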
391 static OpFoldResult
392 getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
393  ArrayRef<int64_t> maybeConstants,
394  ArrayRef<OpFoldResult> values,
395  llvm::function_ref<bool(int64_t)> isDynamic) {
396  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
397  SmallVector<OpFoldResult> inputValues;
398  unsigned numberOfSymbols = 0;
399  unsigned groupSize = indices.size();
400  for (unsigned i = 0; i < groupSize; ++i) {
401  productOfValues =
402  productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
403  unsigned srcIdx = indices[i];
404  int64_t maybeConstant = maybeConstants[srcIdx];
405 
406  inputValues.push_back(isDynamic(maybeConstant)
407  ? values[srcIdx]
408  : builder.getIndexAttr(maybeConstant));
409  }
410 
411  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
412  inputValues);
413 }
414 
415 /// Compute the collapsed size of the given \p collapseShape for the
416 /// \p groupId-th reassociation group.
417 /// \p origSizes hold the sizes of the source shape as values.
418 /// This is used to compute the new sizes in cases of dynamic shapes.
419 ///
420 /// Conceptually this helper function computes:
421 /// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
422 ///
423 /// \post result.size() == 1, in other words, each group collapses into one
424 /// dimension.
425 ///
426 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
427 /// this is not possible because this function uses the Affine dialect and the
428 /// MemRef dialect cannot depend on the Affine dialect.
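///
/// For instance (hypothetical shapes): collapsing source dimensions `?x4`
/// into one dimension yields the single size `origSize * 4`, materialized
/// with an affine.apply; a fully static group such as `2x3` simply yields
/// the constant 6.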
429 static SmallVector<OpFoldResult>
430 getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
431  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
432  SmallVector<OpFoldResult> collapsedSize;
433 
434  MemRefType collapseShapeType = collapseShape.getResultType();
435 
436  uint64_t size = collapseShapeType.getDimSize(groupId);
437  if (!ShapedType::isDynamic(size)) {
438  collapsedSize.push_back(builder.getIndexAttr(size));
439  return collapsedSize;
440  }
441 
442  // We are dealing with a dynamic size.
443  // Build the affine expr of the product of the original sizes involved in that
444  // group.
445  Value source = collapseShape.getSrc();
446  auto sourceType = cast<MemRefType>(source.getType());
447 
448  SmallVector<int64_t, 2> reassocGroup =
449  collapseShape.getReassociationIndices()[groupId];
450 
451  collapsedSize.push_back(getProductOfValues(
452  reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
453  origSizes, ShapedType::isDynamic));
454 
455  return collapsedSize;
456 }
457 
458 /// Compute the collapsed stride of the given \p collapseShape for the
459 /// \p groupId-th reassociation group.
460 /// \p origStrides and \p origSizes hold respectively the strides and sizes
461 /// of the source shape as values.
462 /// This is used to compute the strides in cases of dynamic shapes and/or
463 /// dynamic stride for this reassociation group.
464 ///
465 /// Conceptually this helper function returns the stride of the innermost
466 /// dimension of that group in the original shape.
467 ///
468 /// \post result.size() == 1, in other words, each group collapses into one
469 /// dimension.
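///
/// For instance (hypothetical shapes): collapsing two contiguous source
/// dimensions with strides [4, 1] yields stride min(4, 1) = 1, i.e. the
/// stride of the innermost dimension of the group.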
470 static SmallVector<OpFoldResult>
471 getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
472  ArrayRef<OpFoldResult> origSizes,
473  ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
474  SmallVector<int64_t, 2> reassocGroup =
475  collapseShape.getReassociationIndices()[groupId];
476  assert(!reassocGroup.empty() &&
477  "Reassociation group should have at least one dimension");
478 
479  Value source = collapseShape.getSrc();
480  auto sourceType = cast<MemRefType>(source.getType());
481 
482  auto [strides, offset] = getStridesAndOffset(sourceType);
483 
484  SmallVector<OpFoldResult> groupStrides;
485  ArrayRef<int64_t> srcShape = sourceType.getShape();
486  for (int64_t currentDim : reassocGroup) {
487  // Skip size-of-1 dimensions, since right now their strides may be
488  // meaningless.
489  // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
490  // they are truly contiguous. When they are truly contiguous, we shouldn't
491  // need to skip them.
492  if (srcShape[currentDim] == 1)
493  continue;
494 
495  int64_t currentStride = strides[currentDim];
496  groupStrides.push_back(ShapedType::isDynamic(currentStride)
497  ? origStrides[currentDim]
498  : builder.getIndexAttr(currentStride));
499  }
500  if (groupStrides.empty()) {
501  // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
502  // but we still have to make the type system happy.
503  MemRefType collapsedType = collapseShape.getResultType();
504  auto [collapsedStrides, collapsedOffset] =
505  getStridesAndOffset(collapsedType);
506  int64_t finalStride = collapsedStrides[groupId];
507  if (ShapedType::isDynamic(finalStride)) {
508  // Look for a dynamic stride. At this point we don't know which one is
509  // desired, but they are all equally good/bad.
510  for (int64_t currentDim : reassocGroup) {
511  assert(srcShape[currentDim] == 1 &&
512  "We should be dealing with 1x1x...x1");
513 
514  if (ShapedType::isDynamic(strides[currentDim]))
515  return {origStrides[currentDim]};
516  }
517  llvm_unreachable("We should have found a dynamic stride");
518  }
519  return {builder.getIndexAttr(finalStride)};
520  }
521 
522  // For the general case, we just want the minimum stride
523  // since the collapsed dimensions are contiguous.
524  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
525  builder.getContext());
526  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
527  groupStrides)};
528 }
529 /// Replace `baseBuffer, offset, sizes, strides =
530 /// extract_strided_metadata(reshapeLike(memref))`
531 /// With
532 ///
533 /// \verbatim
534 /// baseBuffer, offset, baseSizes, baseStrides =
535 /// extract_strided_metadata(memref)
536 /// sizes = getReshapedSizes(reshapeLike)
537 /// strides = getReshapedStrides(reshapeLike)
538 /// \endverbatim
539 ///
540 ///
541 /// Notice that `baseBuffer` and `offset` are unchanged.
542 ///
543 /// In other words, get rid of the reshape-like operation in that expression
544 /// and materialize its effects on the sizes and the strides using affine.apply.
545 template <typename ReassociativeReshapeLikeOp,
546  SmallVector<OpFoldResult> (*getReshapedSizes)(
547  ReassociativeReshapeLikeOp, OpBuilder &,
548  ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
549  SmallVector<OpFoldResult> (*getReshapedStrides)(
550  ReassociativeReshapeLikeOp, OpBuilder &,
551  ArrayRef<OpFoldResult> /*origSizes*/,
552  ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
553 struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
554 public:
555  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;
556 
557  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
558  PatternRewriter &rewriter) const override {
559  // Build a plain extract_strided_metadata(memref) from
560  // extract_strided_metadata(reassociative_reshape_like(memref)).
561  Location origLoc = reshape.getLoc();
562  Value source = reshape.getSrc();
563  auto sourceType = cast<MemRefType>(source.getType());
564  unsigned sourceRank = sourceType.getRank();
565 
566  auto newExtractStridedMetadata =
567  rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
568 
569  // Collect statically known information.
570  auto [strides, offset] = getStridesAndOffset(sourceType);
571  MemRefType reshapeType = reshape.getResultType();
572  unsigned reshapeRank = reshapeType.getRank();
573 
574  OpFoldResult offsetOfr =
575  ShapedType::isDynamic(offset)
576  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
577  : rewriter.getIndexAttr(offset);
578 
579  // Get the special case of 0-D out of the way.
580  if (sourceRank == 0) {
581  SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
582  auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
583  origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
584  offsetOfr, /*sizes=*/ones, /*strides=*/ones);
585  rewriter.replaceOp(reshape, memrefDesc.getResult());
586  return success();
587  }
588 
589  SmallVector<OpFoldResult> finalSizes;
590  finalSizes.reserve(reshapeRank);
591  SmallVector<OpFoldResult> finalStrides;
592  finalStrides.reserve(reshapeRank);
593 
594  // Compute the reshaped strides and sizes from the base strides and sizes.
595  SmallVector<OpFoldResult> origSizes =
596  getAsOpFoldResult(newExtractStridedMetadata.getSizes());
597  SmallVector<OpFoldResult> origStrides =
598  getAsOpFoldResult(newExtractStridedMetadata.getStrides());
599  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
600  for (; idx != endIdx; ++idx) {
601  SmallVector<OpFoldResult> reshapedSizes =
602  getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
603  SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
604  reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
605 
606  unsigned groupSize = reshapedSizes.size();
607  for (unsigned i = 0; i < groupSize; ++i) {
608  finalSizes.push_back(reshapedSizes[i]);
609  finalStrides.push_back(reshapedStrides[i]);
610  }
611  }
612  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
613  (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
614  "We should have visited all the input dimensions");
615  assert(finalSizes.size() == reshapeRank &&
616  "We should have populated all the values");
617  auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
618  origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
619  offsetOfr, finalSizes, finalStrides);
620  rewriter.replaceOp(reshape, memrefDesc.getResult());
621  return success();
622  }
623 };
624 
625 /// Replace `base, offset, sizes, strides =
626 /// extract_strided_metadata(allocLikeOp)`
627 ///
628 /// With
629 ///
630 /// ```
631 /// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
632 /// offset = 0
633 /// sizes = allocSizes
634 /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
635 /// ```
636 ///
637 /// The transformation only applies if the allocLikeOp has been normalized.
638 /// In other words, the affine_map must be an identity.
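///
/// For instance (hypothetical shape): for `memref.alloc() : memref<2x3x4xf32>`
/// this produces offset = 0, sizes = [2, 3, 4], and
/// strides = [3 * 4, 4, 1] = [12, 4, 1].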
639 template <typename AllocLikeOp>
640 struct ExtractStridedMetadataOpAllocFolder
641  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
642 public:
643  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
644 
645  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
646  PatternRewriter &rewriter) const override {
647  auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
648  if (!allocLikeOp)
649  return failure();
650 
651  auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
652  if (!memRefType.getLayout().isIdentity())
653  return rewriter.notifyMatchFailure(
654  allocLikeOp, "alloc-like operations should have been normalized");
655 
656  Location loc = op.getLoc();
657  int rank = memRefType.getRank();
658 
659  // Collect the sizes.
660  ValueRange dynamic = allocLikeOp.getDynamicSizes();
661  SmallVector<OpFoldResult> sizes;
662  sizes.reserve(rank);
663  unsigned dynamicPos = 0;
664  for (int64_t size : memRefType.getShape()) {
665  if (ShapedType::isDynamic(size))
666  sizes.push_back(dynamic[dynamicPos++]);
667  else
668  sizes.push_back(rewriter.getIndexAttr(size));
669  }
670 
671  // Strides (just creates identity strides).
672  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
673  AffineExpr expr = rewriter.getAffineConstantExpr(1);
674  unsigned symbolNumber = 0;
675  for (int i = rank - 2; i >= 0; --i) {
676  expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
677  assert(i + 1 + symbolNumber == sizes.size() &&
678  "The ArrayRef should encompass the last #symbolNumber sizes");
679  ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
680  strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
681  sizesInvolvedInStride);
682  }
683 
684  // Put all the values together to replace the results.
685  SmallVector<Value> results;
686  results.reserve(rank * 2 + 2);
687 
688  auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
689  int64_t offset = 0;
690  if (op.getBaseBuffer().use_empty()) {
691  results.push_back(nullptr);
692  } else {
693  if (allocLikeOp.getType() == baseBufferType)
694  results.push_back(allocLikeOp);
695  else
696  results.push_back(rewriter.create<memref::ReinterpretCastOp>(
697  loc, baseBufferType, allocLikeOp, offset,
698  /*sizes=*/ArrayRef<int64_t>(),
699  /*strides=*/ArrayRef<int64_t>()));
700  }
701 
702  // Offset.
703  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
704 
705  for (OpFoldResult size : sizes)
706  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));
707 
708  for (OpFoldResult stride : strides)
709  results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));
710 
711  rewriter.replaceOp(op, results);
712  return success();
713  }
714 };
715 
716 /// Replace `base, offset, sizes, strides =
717 /// extract_strided_metadata(get_global)`
718 ///
719 /// With
720 ///
721 /// ```
722 /// base = reinterpret_cast get_global to a flat memref<eltTy>
723 /// offset = 0
724 /// sizes = allocSizes
725 /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
726 /// ```
727 ///
728 /// It is expected that the memref.get_global op has static shapes
729 /// and identity affine_map for the layout.
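///
/// For instance (hypothetical shape): for a `memref.get_global` of type
/// `memref<4x5xf32>` this produces offset = 0, sizes = [4, 5], and
/// strides = [5, 1].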
730 struct ExtractStridedMetadataOpGetGlobalFolder
731  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
732 public:
733  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
734 
735  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
736  PatternRewriter &rewriter) const override {
737  auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
738  if (!getGlobalOp)
739  return failure();
740 
741  auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
742  if (!memRefType.getLayout().isIdentity()) {
743  return rewriter.notifyMatchFailure(
744  getGlobalOp,
745  "get-global operation result should have been normalized");
746  }
747 
748  Location loc = op.getLoc();
749  int rank = memRefType.getRank();
750 
751  // Collect the sizes.
752  ArrayRef<int64_t> sizes = memRefType.getShape();
753  assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
754  "unexpected dynamic shape for result of `memref.get_global` op");
755 
756  // Strides (just creates identity strides).
757  SmallVector<int64_t> strides = computeSuffixProduct(sizes);
758 
759  // Put all the values together to replace the results.
760  SmallVector<Value> results;
761  results.reserve(rank * 2 + 2);
762 
763  auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
764  int64_t offset = 0;
765  if (getGlobalOp.getType() == baseBufferType)
766  results.push_back(getGlobalOp);
767  else
768  results.push_back(rewriter.create<memref::ReinterpretCastOp>(
769  loc, baseBufferType, getGlobalOp, offset,
770  /*sizes=*/ArrayRef<int64_t>(),
771  /*strides=*/ArrayRef<int64_t>()));
772 
773  // Offset.
774  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
775 
776  for (auto size : sizes)
777  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));
778 
779  for (auto stride : strides)
780  results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));
781 
782  rewriter.replaceOp(op, results);
783  return success();
784  }
785 };
786 
787 /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
788 /// source of the ViewLikeOp.
789 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
790  : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
791  using OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp>::OpRewritePattern;
792 
793  LogicalResult
794  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
795  PatternRewriter &rewriter) const override {
796  auto viewLikeOp =
797  extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
798  if (!viewLikeOp)
799  return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
800  rewriter.updateRootInPlace(extractOp, [&]() {
801  extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
802  });
803  return success();
804  }
805 };
806 
807 /// Replace `base, offset, sizes, strides =
808 /// extract_strided_metadata(
809 /// reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
810 /// With
811 /// ```
812 /// base, ... = extract_strided_metadata(src)
813 /// offset = srcOffset
814 /// sizes = srcSizes
815 /// strides = srcStrides
816 /// ```
817 ///
818 /// In other words, consume the `reinterpret_cast` and apply its effects
819 /// on the offset, sizes, and strides.
820 class ExtractStridedMetadataOpReinterpretCastFolder
821  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
822  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
823 
824  LogicalResult
825  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
826  PatternRewriter &rewriter) const override {
827  auto reinterpretCastOp = extractStridedMetadataOp.getSource()
828  .getDefiningOp<memref::ReinterpretCastOp>();
829  if (!reinterpretCastOp)
830  return failure();
831 
832  Location loc = extractStridedMetadataOp.getLoc();
833  // Check if the source is suitable for extract_strided_metadata.
834  SmallVector<Type> inferredReturnTypes;
835  if (failed(extractStridedMetadataOp.inferReturnTypes(
836  rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
837  /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
838  inferredReturnTypes)))
839  return rewriter.notifyMatchFailure(
840  reinterpretCastOp, "reinterpret_cast source's type is incompatible");
841 
842  auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
843  unsigned rank = memrefType.getRank();
844  SmallVector<OpFoldResult> results;
845  results.resize_for_overwrite(rank * 2 + 2);
846 
847  auto newExtractStridedMetadata =
848  rewriter.create<memref::ExtractStridedMetadataOp>(
849  loc, reinterpretCastOp.getSource());
850 
851  // Register the base_buffer.
852  results[0] = newExtractStridedMetadata.getBaseBuffer();
853 
854  // Register the new offset.
855  results[1] = getValueOrCreateConstantIndexOp(
856  rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
857 
858  const unsigned sizeStartIdx = 2;
859  const unsigned strideStartIdx = sizeStartIdx + rank;
860 
861  SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
862  SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
863  for (unsigned i = 0; i < rank; ++i) {
864  results[sizeStartIdx + i] = sizes[i];
865  results[strideStartIdx + i] = strides[i];
866  }
867  rewriter.replaceOp(extractStridedMetadataOp,
868  getValueOrCreateConstantIndexOp(rewriter, loc, results));
869  return success();
870  }
871 };
872 
873 /// Replace `base, offset, sizes, strides =
874 /// extract_strided_metadata(
875 /// cast(src) to dstTy)`
876 /// With
877 /// ```
878 /// base, ... = extract_strided_metadata(src)
879 /// offset = !dstTy.offset.isDynamic()
880 ///            ? dstTy.offset
881 ///            : extract_strided_metadata(src).offset
882 /// sizes = for each dstSize#i in dstTy.sizes:
883 ///           !dstSize#i.isDynamic()
884 ///             ? dstSize#i
885 ///             : extract_strided_metadata(src).sizes[i]
886 /// strides = for each dstStride#i in dstTy.strides:
887 ///             !dstStride#i.isDynamic()
888 ///               ? dstStride#i
889 ///               : extract_strided_metadata(src).strides[i]
890 /// ```
891 ///
892 /// In other words, consume the `cast` and apply its effects
893 /// on the offset, sizes, and strides or compute them directly from `src`.
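///
/// For instance (hypothetical types): for a cast from `memref<?x?xf32>` to
/// `memref<4x?xf32>`, size#0 = 4, offset = 0, and the innermost stride = 1
/// come from the static information of the destination type, while the
/// remaining size and stride are forwarded from
/// `extract_strided_metadata(src)`.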
894 class ExtractStridedMetadataOpCastFolder
895  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
896  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
897 
898  LogicalResult
899  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
900  PatternRewriter &rewriter) const override {
901  Value source = extractStridedMetadataOp.getSource();
902  auto castOp = source.getDefiningOp<memref::CastOp>();
903  if (!castOp)
904  return failure();
905 
906  Location loc = extractStridedMetadataOp.getLoc();
907  // Check if the source is suitable for extract_strided_metadata.
908  SmallVector<Type> inferredReturnTypes;
909  if (failed(extractStridedMetadataOp.inferReturnTypes(
910  rewriter.getContext(), loc, {castOp.getSource()},
911  /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
912  inferredReturnTypes)))
913  return rewriter.notifyMatchFailure(castOp,
914  "cast source's type is incompatible");
915 
916  auto memrefType = cast<MemRefType>(source.getType());
917  unsigned rank = memrefType.getRank();
918  SmallVector<OpFoldResult> results;
919  results.resize_for_overwrite(rank * 2 + 2);
920 
921  auto newExtractStridedMetadata =
922  rewriter.create<memref::ExtractStridedMetadataOp>(loc,
923  castOp.getSource());
924 
925  // Register the base_buffer.
926  results[0] = newExtractStridedMetadata.getBaseBuffer();
927 
928  auto getConstantOrValue = [&rewriter](int64_t constant,
929  OpFoldResult ofr) -> OpFoldResult {
930  return !ShapedType::isDynamic(constant)
931  ? OpFoldResult(rewriter.getIndexAttr(constant))
932  : ofr;
933  };
934 
935  auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
936  assert(sourceStrides.size() == rank && "unexpected number of strides");
937 
938  // Register the new offset.
939  results[1] =
940  getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
941 
942  const unsigned sizeStartIdx = 2;
943  const unsigned strideStartIdx = sizeStartIdx + rank;
944  ArrayRef<int64_t> sourceSizes = memrefType.getShape();
945 
946  SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
947  SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
948  for (unsigned i = 0; i < rank; ++i) {
949  results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
950  results[strideStartIdx + i] =
951  getConstantOrValue(sourceStrides[i], strides[i]);
952  }
953  rewriter.replaceOp(extractStridedMetadataOp,
954  getValueOrCreateConstantIndexOp(rewriter, loc, results));
955  return success();
956  }
957 };
958 
959 /// Replace `base, offset =
960 /// extract_strided_metadata(extract_strided_metadata(src)#0)`
961 /// With
962 /// ```
963 /// base, ... = extract_strided_metadata(src)
964 /// offset = 0
965 /// ```
966 class ExtractStridedMetadataOpExtractStridedMetadataFolder
967  : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
968  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
969 
970  LogicalResult
971  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
972  PatternRewriter &rewriter) const override {
973  auto sourceExtractStridedMetadataOp =
974  extractStridedMetadataOp.getSource()
975  .getDefiningOp<memref::ExtractStridedMetadataOp>();
976  if (!sourceExtractStridedMetadataOp)
977  return failure();
978  Location loc = extractStridedMetadataOp.getLoc();
979  rewriter.replaceOp(extractStridedMetadataOp,
980  {sourceExtractStridedMetadataOp.getBaseBuffer(),
981  getValueOrCreateConstantIndexOp(
982  rewriter, loc, rewriter.getIndexAttr(0))});
983  return success();
984  }
985 };
986 } // namespace
987 
988 void memref::populateExpandStridedMetadataPatterns(
989  RewritePatternSet &patterns) {
990  patterns.add<SubviewFolder,
991  ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
992  getExpandedStrides>,
993  ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
994  getCollapsedStride>,
995  ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
996  ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
997  ExtractStridedMetadataOpGetGlobalFolder,
998  RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
999  ExtractStridedMetadataOpReinterpretCastFolder,
1000  ExtractStridedMetadataOpCastFolder,
1001  ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1002  patterns.getContext());
1003 }
1004 
1005 void memref::populateResolveExtractStridedMetadataPatterns(
1006  RewritePatternSet &patterns) {
1007  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1008  ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1009  ExtractStridedMetadataOpGetGlobalFolder,
1010  ExtractStridedMetadataOpSubviewFolder,
1011  RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1012  ExtractStridedMetadataOpReinterpretCastFolder,
1013  ExtractStridedMetadataOpCastFolder,
1014  ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1015  patterns.getContext());
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // Pass registration
1020 //===----------------------------------------------------------------------===//
1021 
1022 namespace {
1023 
1024 struct ExpandStridedMetadataPass final
1025  : public memref::impl::ExpandStridedMetadataBase<
1026  ExpandStridedMetadataPass> {
1027  void runOnOperation() override;
1028 };
1029 
1030 } // namespace
1031 
1032 void ExpandStridedMetadataPass::runOnOperation() {
1033  RewritePatternSet patterns(&getContext());
1034  memref::populateExpandStridedMetadataPatterns(patterns);
1035  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1036 }
1037 
1038 std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
1039  return std::make_unique<ExpandStridedMetadataPass>();
1040 }