MLIR  22.0.0git
ComposeSubView.cpp
Go to the documentation of this file.
1 //===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains patterns for combining composed subview ops (i.e. subview
10 // of a subview becomes a single subview).
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
#include "llvm/Support/MathExtras.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 // Replaces a subview of a subview with a single subview(both static and dynamic
27 // offsets are supported).
28 struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
30 
31  LogicalResult matchAndRewrite(memref::SubViewOp op,
32  PatternRewriter &rewriter) const override {
33  // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
34  // produces the input of the op we're rewriting (for 'SubViewOp' the input
35  // is called the "source" value). We can only combine them if both 'op' and
36  // 'sourceOp' are 'SubViewOp'.
37  auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
38  if (!sourceOp)
39  return failure();
40 
41  // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
42  // output memref that are statically known to be equal to 1. We do not
43  // allow 'sourceOp' to be a rank-reducing subview because then our two
44  // 'SubViewOp's would have different numbers of offset/size/stride
45  // parameters (just difficult to deal with, not impossible if we end up
46  // needing it).
47  if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
48  return failure();
49  }
50 
51  // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
52  SmallVector<OpFoldResult> offsets, sizes, strides,
53  opStrides = op.getMixedStrides(),
54  sourceStrides = sourceOp.getMixedStrides();
55 
56  // The output stride in each dimension is equal to the product of the
57  // dimensions corresponding to source and op.
58  int64_t sourceStrideValue;
59  for (auto &&[opStride, sourceStride] :
60  llvm::zip(opStrides, sourceStrides)) {
61  Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
62  Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
63  if (!opStrideAttr || !sourceStrideAttr)
64  return failure();
65  sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
66  strides.push_back(rewriter.getI64IntegerAttr(
67  cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
68  }
69 
70  // The rules for calculating the new offsets and sizes are:
71  // * Multiple subview offsets for a given dimension compose additively.
72  // ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
73  // m + n * k")
74  // * Multiple sizes for a given dimension compose by taking the size of the
75  // final subview and ignoring the rest. ("Take m values" followed by "Take
76  // n values" == "Take n values") This size must also be the smallest one
77  // by definition (a subview needs to be the same size as or smaller than
78  // its source along each dimension; presumably subviews that are larger
79  // than their sources are disallowed by validation).
80  for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
81  llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
82  sourceOp.getMixedStrides(), op.getMixedSizes())) {
83  sizes.push_back(opSize);
84  Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
85  sourceOffsetAttr =
86  llvm::dyn_cast_if_present<Attribute>(sourceOffset),
87  sourceStrideAttr =
88  llvm::dyn_cast_if_present<Attribute>(sourceStride);
89  if (opOffsetAttr && sourceOffsetAttr) {
90 
91  // If both offsets are static we can simply calculate the combined
92  // offset statically.
93  offsets.push_back(rewriter.getI64IntegerAttr(
94  cast<IntegerAttr>(opOffsetAttr).getInt() *
95  cast<IntegerAttr>(sourceStrideAttr).getInt() +
96  cast<IntegerAttr>(sourceOffsetAttr).getInt()));
97  } else {
98  AffineExpr expr;
99  SmallVector<Value> affineApplyOperands;
100 
101  // Make 'expr' add 'sourceOffset'.
102  if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
103  expr =
104  rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
105  } else {
106  expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
107  affineApplyOperands.push_back(cast<Value>(sourceOffset));
108  }
109 
110  // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
111  // result.
112  if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
113  expr = expr + cast<IntegerAttr>(attr).getInt() *
114  cast<IntegerAttr>(sourceStrideAttr).getInt();
115  } else {
116  expr =
117  expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
118  cast<IntegerAttr>(sourceStrideAttr).getInt();
119  affineApplyOperands.push_back(cast<Value>(opOffset));
120  }
121 
122  AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
123  Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map,
124  affineApplyOperands);
125  offsets.push_back(result);
126  }
127  }
128 
129  // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
130  // uses it can be removed by a (separate) dead code elimination pass.
131  rewriter.replaceOpWithNewOp<memref::SubViewOp>(
132  op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
133  return success();
134  }
135 };
136 
137 } // namespace
138 
140  MLIRContext *context) {
141  patterns.add<ComposeSubViewOpPattern>(context);
142 }
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition: Attributes.h:25
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:363
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:367
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateComposeSubViewPatterns(RewritePatternSet &patterns, MLIRContext *context)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319