MLIR 22.0.0git
ViewLikeInterface.cpp
Go to the documentation of this file.
1//===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11using namespace mlir;
12
13//===----------------------------------------------------------------------===//
14// ViewLike Interfaces
15//===----------------------------------------------------------------------===//
16
17/// Include the definitions of the view-like interfaces.
18#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19
21 StringRef name,
22 unsigned numElements,
23 ArrayRef<int64_t> staticVals,
24 ValueRange values) {
25 // Check static and dynamic offsets/sizes/strides does not overflow type.
26 if (staticVals.size() != numElements)
27 return op->emitError("expected ") << numElements << " " << name
28 << " values, got " << staticVals.size();
29 unsigned expectedNumDynamicEntries =
30 llvm::count_if(staticVals, ShapedType::isDynamic);
31 if (values.size() != expectedNumDynamicEntries)
32 return op->emitError("expected ")
33 << expectedNumDynamicEntries << " dynamic " << name << " values";
34 return success();
35}
36
39 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
40 bool generateErrorMessage) {
42 result.isValid = true;
43 for (int64_t i = 0, e = shape.size(); i < e; ++i) {
44 // Nothing to verify for dynamic source dims.
45 if (ShapedType::isDynamic(shape[i]))
46 continue;
47 // Nothing to verify if the offset is dynamic.
48 if (ShapedType::isDynamic(staticOffsets[i]))
49 continue;
50 if (staticOffsets[i] >= shape[i]) {
51 result.errorMessage =
52 std::string("offset ") + std::to_string(i) +
53 " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
54 " >= " + std::to_string(shape[i]);
55 result.isValid = false;
56 return result;
57 }
58 if (ShapedType::isDynamic(staticSizes[i]) ||
59 ShapedType::isDynamic(staticStrides[i]))
60 continue;
61 int64_t lastPos =
62 staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
63 if (lastPos >= shape[i]) {
64 result.errorMessage = std::string("slice along dimension ") +
65 std::to_string(i) +
66 " runs out-of-bounds: " + std::to_string(lastPos) +
67 " >= " + std::to_string(shape[i]);
68 result.isValid = false;
69 return result;
70 }
71 }
72 return result;
73}
74
77 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
78 bool generateErrorMessage) {
79 auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
80 SmallVector<int64_t> staticValues;
81 for (OpFoldResult ofr : ofrs) {
82 if (auto attr = dyn_cast<Attribute>(ofr)) {
83 staticValues.push_back(cast<IntegerAttr>(attr).getInt());
84 } else {
85 staticValues.push_back(ShapedType::kDynamic);
86 }
87 }
88 return staticValues;
89 };
91 shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
92 getStaticValues(mixedStrides), generateErrorMessage);
93}
94
95LogicalResult
96mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
97 // A dynamic size is represented as ShapedType::kDynamic in `static_sizes`.
98 // Its corresponding Value appears in `sizes`. Thus, the number of dynamic
99 // dimensions in `static_sizes` must equal the rank of `sizes`.
100 // The same applies to strides and offsets.
101 size_t numDynamicDims =
102 llvm::count_if(op.getStaticSizes(), ShapedType::isDynamic);
103 if (op.getSizes().size() != numDynamicDims) {
104 return op->emitError("expected the number of 'sizes' to match the number "
105 "of dynamic entries in 'static_sizes' (")
106 << op.getSizes().size() << " vs " << numDynamicDims << ")";
107 }
108 size_t numDynamicStrides =
109 llvm::count_if(op.getStaticStrides(), ShapedType::isDynamic);
110 if (op.getStrides().size() != numDynamicStrides) {
111 return op->emitError("expected the number of 'strides' to match the number "
112 "of dynamic entries in 'static_strides' (")
113 << op.getStrides().size() << " vs " << numDynamicStrides << ")";
114 }
115 size_t numDynamicOffsets =
116 llvm::count_if(op.getStaticOffsets(), ShapedType::isDynamic);
117 if (op.getOffsets().size() != numDynamicOffsets) {
118 return op->emitError("expected the number of 'offsets' to match the number "
119 "of dynamic entries in 'static_offsets' (")
120 << op.getOffsets().size() << " vs " << numDynamicOffsets << ")";
121 }
122
123 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
124 // Offsets can come in 2 flavors:
125 // 1. Either single entry (when maxRanks == 1).
126 // 2. Or as an array whose rank must match that of the mixed sizes.
127 // So that the result type is well-formed.
128 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
129 op.getMixedOffsets().size() != op.getMixedSizes().size())
130 return op->emitError(
131 "expected mixed offsets rank to match mixed sizes rank (")
132 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
133 << ") so the rank of the result type is well-formed.";
134 // Ranks of mixed sizes and strides must always match so the result type is
135 // well-formed.
136 if (op.getMixedSizes().size() != op.getMixedStrides().size())
137 return op->emitError(
138 "expected mixed sizes rank to match mixed strides rank (")
139 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
140 << ") so the rank of the result type is well-formed.";
141
143 op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
144 return failure();
146 op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
147 return failure();
149 op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
150 return failure();
151
152 for (int64_t offset : op.getStaticOffsets()) {
153 if (offset < 0 && ShapedType::isStatic(offset))
154 return op->emitError("expected offsets to be non-negative, but got ")
155 << offset;
156 }
157 for (int64_t size : op.getStaticSizes()) {
158 if (size < 0 && ShapedType::isStatic(size))
159 return op->emitError("expected sizes to be non-negative, but got ")
160 << size;
161 }
162 return success();
163}
164
166 switch (delimiter) {
168 return '(';
170 return '<';
172 return '[';
174 return '{';
175 default:
176 llvm_unreachable("unsupported delimiter");
177 }
178}
179
181 switch (delimiter) {
183 return ')';
185 return '>';
187 return ']';
189 return '}';
190 default:
191 llvm_unreachable("unsupported delimiter");
192 }
193}
194
196 OperandRange values,
197 ArrayRef<int64_t> integers,
198 ArrayRef<bool> scalableFlags,
199 TypeRange valueTypes,
200 AsmParser::Delimiter delimiter) {
201 char leftDelimiter = getLeftDelimiter(delimiter);
202 char rightDelimiter = getRightDelimiter(delimiter);
203 printer << leftDelimiter;
204 if (integers.empty()) {
205 printer << rightDelimiter;
206 return;
207 }
208
209 unsigned dynamicValIdx = 0;
210 unsigned scalableIndexIdx = 0;
211 llvm::interleaveComma(integers, printer, [&](int64_t integer) {
212 if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
213 printer << "[";
214 if (ShapedType::isDynamic(integer)) {
215 printer << values[dynamicValIdx];
216 if (!valueTypes.empty())
217 printer << " : " << valueTypes[dynamicValIdx];
218 ++dynamicValIdx;
219 } else {
220 printer << integer;
221 }
222 if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
223 printer << "]";
224
225 scalableIndexIdx++;
226 });
227
228 printer << rightDelimiter;
229}
230
232 OpAsmParser &parser,
234 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
235 SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
236
237 SmallVector<int64_t, 4> integerVals;
238 SmallVector<bool, 4> scalableVals;
239 auto parseIntegerOrValue = [&]() {
241 auto res = parser.parseOptionalOperand(operand);
242
243 // When encountering `[`, assume that this is a scalable index.
244 scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
245
246 if (res.has_value() && succeeded(res.value())) {
247 values.push_back(operand);
248 integerVals.push_back(ShapedType::kDynamic);
249 if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
250 return failure();
251 } else {
252 int64_t integer;
253 if (failed(parser.parseInteger(integer)))
254 return failure();
255 integerVals.push_back(integer);
256 }
257
258 // If this is assumed to be a scalable index, verify that there's a closing
259 // `]`.
260 if (scalableVals.back() && parser.parseOptionalRSquare().failed())
261 return failure();
262 return success();
263 };
264 if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
265 " in dynamic index list"))
266 return parser.emitError(parser.getNameLoc())
267 << "expected SSA value or integer";
268 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
269 scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
270 return success();
271}
272
274 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
276 if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
277 return false;
278 if (a.getStaticSizes().size() != b.getStaticSizes().size())
279 return false;
280 if (a.getStaticStrides().size() != b.getStaticStrides().size())
281 return false;
282 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
283 if (!cmp(std::get<0>(it), std::get<1>(it)))
284 return false;
285 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
286 if (!cmp(std::get<0>(it), std::get<1>(it)))
287 return false;
288 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
289 if (!cmp(std::get<0>(it), std::get<1>(it)))
290 return false;
291 return true;
292}
293
295 unsigned idx) {
296 return std::count_if(staticVals.begin(), staticVals.begin() + idx,
297 ShapedType::isDynamic);
298}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static char getLeftDelimiter(AsmParser::Delimiter delimiter)
static char getRightDelimiter(AsmParser::Delimiter delimiter)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
@ LessGreater
<> brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef< bool > values)
Tensor-typed DenseArrayAttr getters.
Definition Builders.cpp:151
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
unsigned getNumDynamicEntriesUpToIdx(ArrayRef< int64_t > staticVals, unsigned idx)
Helper method to compute the number of dynamic entries of staticVals, up to idx.
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values has as many elements as the number of entries in attr for which isDynamic ev...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
Result for slice bounds verification.