MLIR 22.0.0git
ComposeSubView.cpp
Go to the documentation of this file.
1//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for combining composed subview ops (i.e. subview
10// of a subview becomes a single subview).
11//
12//===----------------------------------------------------------------------===//
13
21
22using namespace mlir;
23
24namespace {
25
26// Replaces a subview of a subview with a single subview(both static and dynamic
27// offsets are supported).
28struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
30
31 LogicalResult matchAndRewrite(memref::SubViewOp op,
32 PatternRewriter &rewriter) const override {
33 // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
34 // produces the input of the op we're rewriting (for 'SubViewOp' the input
35 // is called the "source" value). We can only combine them if both 'op' and
36 // 'sourceOp' are 'SubViewOp'.
37 auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
38 if (!sourceOp)
39 return failure();
40
41 // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
42 // output memref that are statically known to be equal to 1. We do not
43 // allow 'sourceOp' to be a rank-reducing subview because then our two
44 // 'SubViewOp's would have different numbers of offset/size/stride
45 // parameters (just difficult to deal with, not impossible if we end up
46 // needing it).
47 if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
48 return failure();
49 }
50
51 // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
52 SmallVector<OpFoldResult> offsets, sizes, strides,
53 opStrides = op.getMixedStrides(),
54 sourceStrides = sourceOp.getMixedStrides();
55
56 // The output stride in each dimension is equal to the product of the
57 // dimensions corresponding to source and op.
58 int64_t sourceStrideValue;
59 for (auto &&[opStride, sourceStride] :
60 llvm::zip(opStrides, sourceStrides)) {
61 Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
62 Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
63 if (!opStrideAttr || !sourceStrideAttr)
64 return failure();
65 sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
66 strides.push_back(rewriter.getI64IntegerAttr(
67 cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
68 }
69
70 // The rules for calculating the new offsets and sizes are:
71 // * Multiple subview offsets for a given dimension compose additively.
72 // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
73 // m + n * k")
74 // * Multiple sizes for a given dimension compose by taking the size of the
75 // final subview and ignoring the rest. ("Take m values" followed by "Take
76 // n values" == "Take n values") This size must also be the smallest one
77 // by definition (a subview needs to be the same size as or smaller than
78 // its source along each dimension; presumably subviews that are larger
79 // than their sources are disallowed by validation).
80 for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
81 llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
82 sourceOp.getMixedStrides(), op.getMixedSizes())) {
83 sizes.push_back(opSize);
84 Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
85 sourceOffsetAttr =
86 llvm::dyn_cast_if_present<Attribute>(sourceOffset),
87 sourceStrideAttr =
88 llvm::dyn_cast_if_present<Attribute>(sourceStride);
89 if (opOffsetAttr && sourceOffsetAttr) {
90
91 // If both offsets are static we can simply calculate the combined
92 // offset statically.
93 offsets.push_back(rewriter.getI64IntegerAttr(
94 cast<IntegerAttr>(opOffsetAttr).getInt() *
95 cast<IntegerAttr>(sourceStrideAttr).getInt() +
96 cast<IntegerAttr>(sourceOffsetAttr).getInt()));
97 } else {
98 AffineExpr expr;
99 SmallVector<Value> affineApplyOperands;
100
101 // Make 'expr' add 'sourceOffset'.
102 if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
103 expr =
104 rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
105 } else {
106 expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
107 affineApplyOperands.push_back(cast<Value>(sourceOffset));
108 }
109
110 // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
111 // result.
112 if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
113 expr = expr + cast<IntegerAttr>(attr).getInt() *
114 cast<IntegerAttr>(sourceStrideAttr).getInt();
115 } else {
116 expr =
117 expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
118 cast<IntegerAttr>(sourceStrideAttr).getInt();
119 affineApplyOperands.push_back(cast<Value>(opOffset));
120 }
121
122 AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
123 Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map,
124 affineApplyOperands);
125 offsets.push_back(result);
126 }
127 }
128
129 // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
130 // uses it can be removed by a (separate) dead code elimination pass.
131 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
132 op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
133 return success();
134 }
135};
136
137} // namespace
138
140 MLIRContext *context) {
141 patterns.add<ComposeSubViewOpPattern>(context);
142}
return success()
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:372
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void populateComposeSubViewPatterns(RewritePatternSet &patterns, MLIRContext *context)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...