MLIR  21.0.0git
ViewLikeInterface.h
Go to the documentation of this file.
1 //===- ViewLikeInterface.h - View-like operations interface ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the operation interface for view-like operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
14 #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
15 
17 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/PatternMatch.h"
22 
23 namespace mlir {
24 
25 class OffsetSizeAndStrideOpInterface;
26 
27 namespace detail {
28 
29 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
30 
32  OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
33  llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
34 
35 /// Helper method to compute the number of dynamic entries of `staticVals`,
36 /// up to `idx`.
37 unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
38  unsigned idx);
39 
40 } // namespace detail
41 } // namespace mlir
42 
43 /// Include the generated interface declarations.
44 #include "mlir/Interfaces/ViewLikeInterface.h.inc"
45 
46 namespace mlir {
47 
48 /// Result for slice bounds verification.
50  /// If set to "true", the slice bounds verification was successful.
51  bool isValid;
52  /// An error message that can be printed during op verification.
53  std::string errorMessage;
54 };
55 
56 /// Verify that the offsets/sizes/strides-style access into the given shape
57 /// is in-bounds. Only static values are verified. If `generateErrorMessage`
58 /// is set to "true", an error message is produced that can be printed by the
59 /// op verifier.
62  ArrayRef<int64_t> staticSizes,
63  ArrayRef<int64_t> staticStrides,
64  bool generateErrorMessage = false);
66  ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
67  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
68  bool generateErrorMessage = false);
69 
70 /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
71 /// constant arguments. This pattern assumes that the op has a suitable builder
72 /// that takes a result type, a "source" operand and mixed offsets, sizes and
73 /// strides.
74 ///
75 /// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn`
76 /// returns the new result type of the op, based on the new offsets, sizes and
77 /// strides. `CastOpFunc` is used to generate a cast op if the result type of
78 /// the op has changed.
79 template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
81  : public OpRewritePattern<OpType> {
82 public:
84 
85  LogicalResult matchAndRewrite(OpType op,
86  PatternRewriter &rewriter) const override {
87  SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
88  SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
89  SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
90 
91  // No constant operands were folded, just return;
92  if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
93  failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
94  failed(foldDynamicIndexList(mixedStrides)))
95  return failure();
96 
97  // Pattern does not apply if the produced op would not verify.
99  cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
100  mixedSizes, mixedStrides);
101  if (!sliceResult.isValid)
102  return failure();
103 
104  // Compute the new result type.
105  auto resultType =
106  ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
107  if (!resultType)
108  return failure();
109 
110  // Create the new op in canonical form.
111  auto newOp =
112  rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
113  mixedOffsets, mixedSizes, mixedStrides);
114  CastOpFunc()(rewriter, op, newOp);
115 
116  return success();
117  }
118 };
119 
120 /// Printer hooks for custom directive in assemblyFormat.
121 ///
122 /// custom<DynamicIndexList>($values, $integers)
123 /// custom<DynamicIndexList>($values, $integers, type($values))
124 ///
125 /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS type
126 /// `DenseI64ArrayAttr`. Print a list where each element is either:
127 /// 1. the static integer value in `integers`, if it's not `kDynamic` or,
128 /// 2. the next value in `values`, otherwise.
129 ///
130 /// If `valueTypes` is provided, the corresponding type of each dynamic value is
131 /// printed. Otherwise, the type is not printed. Each type must match the type
132 /// of the corresponding value in `values`. `valueTypes` is redundant for
133 /// printing as we can retrieve the types from the actual `values`. However,
134 /// `valueTypes` is needed for parsing and we must keep the API symmetric for
135 /// parsing and printing. The type for integer elements is `i64` by default and
136 /// never printed.
137 ///
138 /// Integer indices can also be scalable in the context of scalable vectors,
139 /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in
140 /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's
141 /// a scalable index. If `scalableFlags` is empty then assume that all indices
142 /// are non-scalable.
143 ///
144 /// Examples:
145 ///
146 /// * Input: `integers = [kDynamic, 7, 42, kDynamic]`,
147 /// `values = [%arg0, %arg42]` and
148 /// `valueTypes = [index, index]`
149 /// prints:
150 /// `[%arg0 : index, 7, 42, %arg42 : index]`
151 ///
152 /// * Input: `integers = [kDynamic, 7, 42, kDynamic]`,
153 /// `values = [%arg0, %arg42]` and
154 /// `valueTypes = []`
155 /// prints:
156 /// `[%arg0, 7, 42, %arg42]`
157 ///
158 /// * Input: `integers = [2, 4, 8]`,
159 /// `values = []` and
160 /// `scalableFlags = [false, true, false]`
161 /// prints:
162 /// `[2, [4], 8]`
163 ///
165  OpAsmPrinter &printer, Operation *op, OperandRange values,
166  ArrayRef<int64_t> integers, ArrayRef<bool> scalableFlags,
167  TypeRange valueTypes = TypeRange(),
170  OpAsmPrinter &printer, Operation *op, OperandRange values,
171  ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
173  return printDynamicIndexList(printer, op, values, integers,
174  /*scalableFlags=*/{}, valueTypes, delimiter);
175 }
176 
177 /// Parser hooks for custom directive in assemblyFormat.
178 ///
179 /// custom<DynamicIndexList>($values, $integers)
180 /// custom<DynamicIndexList>($values, $integers, type($values))
181 ///
182 /// where `values` is of ODS type `Variadic<*>` and `integers` is of ODS
183 /// type `DenseI64ArrayAttr`. Parse a mixed list where each element is either a
184 /// static integer or an SSA value. Fill `integers` with the integer array attribute,
185 /// where `kDynamic` encodes the position of SSA values. Add the parsed SSA
186 /// values to `values` in-order.
187 ///
188 /// If `valueTypes` is provided, fill it with the types corresponding to each
189 /// value in `values`. Otherwise, the caller must handle the types and parsing
190 /// will fail if a type annotation is present (e.g., `[%arg0 : index, 3,
191 /// %arg1 : index]`).
192 ///
193 /// Integer indices can also be scalable in the context of scalable vectors,
194 /// denoted by square brackets (e.g., "[2, [4], 8]"). For each value in
195 /// `integers`, the corresponding `bool` in `scalableFlags` encodes whether it's
196 /// a scalable index.
197 ///
198 /// Examples:
199 ///
200 /// * After parsing "[%arg0 : index, 7, 42, %arg42 : i32]":
201 /// 1. `integers` is filled with `[kDynamic, 7, 42, kDynamic]`
202 /// 2. `values` is filled with "[%arg0, %arg42]".
203 /// 3. `scalableFlags` is filled with `[false, false, false, false]`.
204 ///
205 /// * After parsing `[2, [4], 8]`:
206 /// 1. `integers` is filled with `[2, 4, 8]`
207 /// 2. `values` is empty.
208 /// 3. `scalableFlags` is filled with `[false, true, false]`.
209 ///
210 ParseResult parseDynamicIndexList(
211  OpAsmParser &parser,
212  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
213  DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
214  SmallVectorImpl<Type> *valueTypes = nullptr,
216 inline ParseResult parseDynamicIndexList(
217  OpAsmParser &parser,
219  DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
221  DenseBoolArrayAttr scalableFlags;
222  return parseDynamicIndexList(parser, values, integers, scalableFlags,
223  valueTypes, delimiter);
224 }
225 
226 /// Verify that `values` has as many elements as the number of entries in
227 /// `attr` for which `isDynamic` evaluates to true.
228 LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
229  unsigned expectedNumElements,
230  ArrayRef<int64_t> attr,
231  ValueRange values);
232 
233 } // namespace mlir
234 
235 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
unsigned getNumDynamicEntriesUpToIdx(ArrayRef< int64_t > staticVals, unsigned idx)
Helper method to compute the number of dynamic entries of staticVals, up to idx.
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
Include the generated interface declarations.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.