//===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, this pass transforms operations into an explicit sequence of
/// operations that model the effect of the original operation on the different
/// metadata. This pass uses affine constructs to materialize these effects.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {

struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}
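///
/// For example, with hypothetical values baseOffset = 0,
/// baseStrides = [128, 1], subOffsets = [3, 4], subSizes = [4, 16], and
/// subStrides = [2, 1], this computes strides = [2 * 128, 1 * 1] = [256, 1],
/// offset = 0 + 3 * 128 + 4 * 1 = 388, and sizes = [4, 16].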
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  // Build a plain extract_strided_metadata(memref) from subview(memref).
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);

  auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
  auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
#endif // NDEBUG

  // Compute the new strides and offset from the base strides and offset:
  //   newStride#i = baseStride#i * subStride#i
  //   offset = baseOffset + sum(subOffset#i * baseStride#i)
  SmallVector<OpFoldResult> strides;
  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
  auto origStrides = newExtractStridedMetadata.getStrides();

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);
  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  // Compute the offset.
  OpFoldResult finalOffset =
      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
#ifndef NDEBUG
  // Assert that the computed offset matches the offset of the result type of
  // the subview op (if both are static).
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && ShapedType::isStatic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
#endif // NDEBUG

  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
  // them all.
  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover, since subviews can drop some dimensions, some strides and sizes
  // may not end up in the final <base, offset, sizes, strides> value that we
  // are replacing.
  // Do the filtering here.
  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(subRank);

  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(subRank);

#ifndef NDEBUG
  // Iteration variable for result dimensions of the subview op.
  int64_t j = 0;
#endif // NDEBUG
  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);
#ifndef NDEBUG
    // Assert that the computed stride matches the stride of the result type
    // of the subview op (if both are static).
    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
    if (computedStride && ShapedType::isStatic(resultStrides[j]))
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
    ++j;
#endif // NDEBUG
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and materialize
/// its effects on the offset, the sizes, and the strides using affine.apply.
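///
/// For instance, on a hypothetical `%dst = memref.subview %src[2, 4][4, 8][1, 1]`
/// with `%src : memref<16x32xf32>`, the pattern would produce, modulo folding,
/// something along the lines of:
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(%src)
/// // offset = 0 + 2 * 32 + 4 * 1 = 68, sizes = [4, 8], strides = [32, 1]
/// %dst = reinterpret_cast baseBuffer, 68, [4, 8], [32, 1]
/// \endverbatim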
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(subview)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }
    Location loc = subviewOp.getLoc();
    SmallVector<Value> results;
    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);

    return success();
  }
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                  for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
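///
/// For example, assuming a hypothetical group that expands one source
/// dimension into result dimensions of sizes [2, ?, 5], the static sizes 2
/// and 5 come directly from the result type, and the dynamic one is computed
/// as baseSizes#groupId floordiv (2 * 5), i.e., affine_apply
/// (s0 floordiv 10)(baseSize).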
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  SmallVector<OpFoldResult> expandedSizes(groupSize);

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  // Fill up all the statically known sizes.
  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }
    productOfAllStaticSizes *= dimSize;
    expandedSizes[i] = builder.getIndexAttr(dimSize);
  }

  // Compute the dynamic size using the original size and all the other known
  // static sizes:
  //   expandSize = origSize / productOfAllStaticSizes.
  if (dynSizeIdx) {
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
  }

  return expandedSizes;
}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
///     origStrides#reassDim * product(expandShapeSizes#j, for j in
///                                     reassIdx#i+1..reassIdx#i+group.size-1)
///
/// Where reassIdx#i is the reassociation index at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type
///   of the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
///   element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
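///
/// For example, assuming a hypothetical group that expands a source dimension
/// of stride S into result dimensions of static sizes [3, 4], the suffix
/// products within the group are [4, 1], so the expanded strides become
/// [4 * S, 1 * S].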
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int64_t> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = sourceType.getStridesAndOffset();

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix-up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride =
          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
              .getInt();
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
            .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
/// values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
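///
/// For example, with hypothetical indices = [0, 2],
/// maybeConstants = [4, dyn, dyn], and values = [%a, %b, %c], this would
/// materialize `affine_apply (s0 * s1)(4, %c)`, i.e., 4 * %c.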
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                   ArrayRef<int64_t> maybeConstants,
                   ArrayRef<OpFoldResult> values,
                   llvm::function_ref<bool(int64_t)> isDynamic) {
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
  SmallVector<OpFoldResult> inputValues;
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    productOfValues =
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
                              ? values[srcIdx]
                              : builder.getIndexAttr(maybeConstant));
  }

  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
                                       inputValues);
}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
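///
/// For example, assuming a hypothetical group that collapses source
/// dimensions of sizes [3, ?, 5] into one dynamic result dimension, the
/// collapsed size is materialized as 3 * origSizes#dynDim * 5, which folds to
/// `affine_apply (s0 * 15)(origSize)`.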
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<OpFoldResult> collapsedSize;

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (ShapedType::isStatic(size)) {
    collapsedSize.push_back(builder.getIndexAttr(size));
    return collapsedSize;
  }

  // We are dealing with a dynamic size.
  // Build the affine expr of the product of the original sizes involved in
  // that group.
  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
}

/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
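///
/// For example, assuming a hypothetical group of source dimensions with sizes
/// [4, 8] and strides [128, 16], the collapsed dimension gets stride 16,
/// i.e., the stride of the innermost dimension of the group.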
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = sourceType.getStridesAndOffset();

  ArrayRef<int64_t> srcShape = sourceType.getShape();

  OpFoldResult lastValidStride = nullptr;
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
    // they are truly contiguous. When they are truly contiguous, we shouldn't
    // need to skip them.
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    lastValidStride = ShapedType::isDynamic(currentStride)
                          ? origStrides[currentDim]
                          : builder.getIndexAttr(currentStride);
  }
  if (!lastValidStride) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        collapsedType.getStridesAndOffset();
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {
      // Look for a dynamic stride. At this point we don't know which one is
      // desired, but they are all equally good/bad.
      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");
    }
    return {builder.getIndexAttr(finalStride)};
  }

  return {lastValidStride};
}

/// From `reshape_like(memref, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, baseOffset, sizes, strides}
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
        getReshapedSizes,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/,
        ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
        getReshapedStrides) {
  // Build a plain extract_strided_metadata(memref) from
  // extract_strided_metadata(reassociative_reshape_like(memref)).
  Location origLoc = reshape.getLoc();
  Value source = reshape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source);

  // Collect statically known information.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  MemRefType reshapeType = reshape.getResultType();
  unsigned reshapeRank = reshapeType.getRank();

  OpFoldResult offsetOfr =
      ShapedType::isDynamic(offset)
          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
          : rewriter.getIndexAttr(offset);

  // Get the special case of 0-D out of the way.
  if (sourceRank == 0) {
    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                           /*sizes=*/ones, /*strides=*/ones};
  }

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(reshapeRank);
  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(reshapeRank);

  // Compute the reshaped strides and sizes from the base strides and sizes.
  SmallVector<OpFoldResult> origSizes =
      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
  SmallVector<OpFoldResult> origStrides =
      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
  for (; idx != endIdx; ++idx) {
    SmallVector<OpFoldResult> reshapedSizes =
        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

    unsigned groupSize = reshapedSizes.size();
    for (unsigned i = 0; i < groupSize; ++i) {
      finalSizes.push_back(reshapedSizes[i]);
      finalStrides.push_back(reshapedStrides[i]);
    }
  }
  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
         "We should have visited all the input dimensions");
  assert(finalSizes.size() == reshapeRank &&
         "We should have populated all the values");

  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                         finalSizes, finalStrides};
}

/// Replace `baseBuffer, offset, sizes, strides =
///             extract_strided_metadata(reshapeLike(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshapeLike)
/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the reshape-like operation in that expression
/// and materialize its effects on the sizes and the strides using affine
/// apply.
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
            rewriter, reshape, getReshapedSizes, getReshapedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(reshape,
                                         "failed to resolve reshape metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        reshape, reshape.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// with the results of computing the sizes and strides of the collapsed shape
/// directly from the source's metadata; `baseBuffer` and `offset` are
/// unchanged.
struct ExtractStridedMetadataOpCollapseShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
            rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op,
          "failed to resolve metadata in terms of source collapse_shape op");
    }

    Location loc = collapseShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(expand_shape)`
/// with the results of computing the sizes and strides on the expanded shape
/// and dividing up dimensions into static and dynamic parts as needed.
struct ExtractStridedMetadataOpExpandShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source expand_shape op");
    }

    Location loc = expandShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(allocLikeOp)`
///
/// With
///
/// ```
/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// The transformation only applies if the allocLikeOp has been normalized.
/// In other words, the affine_map must be an identity.
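///
/// For example, for a hypothetical `memref.alloc(%d) : memref<5x?x7xf32>`,
/// this yields offset = 0, sizes = [5, %d, 7], and strides = [%d * 7, 7, 1].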
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
    if (!allocLikeOp)
      return failure();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ValueRange dynamic = allocLikeOp.getDynamicSizes();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);
      else
        sizes.push_back(rewriter.getIndexAttr(size));
    }

    // Strides (just creates identity strides).
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
                                                 sizesInvolvedInStride);
    }

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);
    } else {
      if (allocLikeOp.getType() == baseBufferType)
        results.push_back(allocLikeOp);
      else
        results.push_back(memref::ReinterpretCastOp::create(
            rewriter, loc, baseBufferType, allocLikeOp, offset,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>()));
    }

    // Offset.
    results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset));

    for (OpFoldResult size : sizes)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));

    for (OpFoldResult stride : strides)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(get_global)`
///
/// With
///
/// ```
/// base = reinterpret_cast get_global to a flat memref<eltTy>
/// offset = 0
/// sizes = globalSizes
/// strides#i = prod(globalSizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and identity affine_map for the layout.
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp)
      return failure();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
    }

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ArrayRef<int64_t> sizes = memRefType.getShape();
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    // Strides (just creates identity strides).
    SmallVector<int64_t> strides = computeSuffixProduct(sizes);

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);
    else
      results.push_back(memref::ReinterpretCastOp::create(
          rewriter, loc, baseBufferType, getGlobalOp, offset,
          /*sizes=*/ArrayRef<int64_t>(),
          /*strides=*/ArrayRef<int64_t>()));

    // Offset.
    results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset));

    for (auto size : sizes)
      results.push_back(arith::ConstantIndexOp::create(rewriter, loc, size));

    for (auto stride : strides)
      results.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(assume_alignment)`
///
/// With
/// \verbatim
/// extract_strided_metadata(memref)
/// \endverbatim
///
/// Since `assume_alignment` is a view-like op that does not modify the
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
/// from its result is equivalent to extracting it from its source. This
/// canonicalization removes the unnecessary indirection.
struct ExtractStridedMetadataOpAssumeAlignmentFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto assumeAlignmentOp =
        op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
    if (!assumeAlignmentOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
        op, assumeAlignmentOp.getViewSource());
    return success();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
    rewriter.modifyOpInPlace(extractOp, [&]() {
      extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
    });
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(
///                 reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, reinterpretCastOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    // Register the new offset.
    results[1] = getValueOrCreateConstantIndexOp(
        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
///             memory_space_cast(src) to dstTy)`
/// with
/// ```
/// oldBase, offset, sizes, strides = extract_strided_metadata(src)
/// destBaseTy = type(oldBase) with memory space from destTy
/// base = memory_space_cast(oldBase) to destBaseTy
/// ```
///
/// In other words, propagate metadata extraction across memory space casts.
class ExtractStridedMetadataOpMemorySpaceCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();
    auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
    if (!memSpaceCastOp)
      return failure();
    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, memSpaceCastOp.getSource());
    SmallVector<Value> results(newExtractStridedMetadata.getResults());
    // As with most other strided metadata rewrite patterns, don't introduce
    // a use of the base pointer where none existed. This needs to happen here,
    // as opposed to in later dead-code elimination, because these patterns are
    // sometimes used during dialect conversion (see EmulateNarrowType, for
    // example), so adding spurious usages would cause a pre-legalization value
    // to be live that would be dead had this pattern not run.
    if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
      auto baseBuffer = results[0];
      auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
      MemRefType::Builder newTypeBuilder(baseBufferType);
      newTypeBuilder.setMemorySpace(
          memSpaceCastOp.getResult().getType().getMemorySpace());
      results[0] = memref::MemorySpaceCastOp::create(
          rewriter, loc, Type{newTypeBuilder}, baseBuffer);
    } else {
      results[0] = nullptr;
    }
    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

/// Replace `base, offset =
///             extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)
      return failure();
    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
                        getValueOrCreateConstantIndexOp(
                            rewriter, loc, rewriter.getIndexAttr(0))});
    return success();
  }
};
} // namespace

void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SubviewFolder,
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpSubviewFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataPassBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

} // namespace

void ExpandStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}