MLIR  22.0.0git
ViewLikeInterface.cpp
Go to the documentation of this file.
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
17 /// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
21  StringRef name,
22  unsigned numElements,
23  ArrayRef<int64_t> staticVals,
24  ValueRange values) {
25  // Check static and dynamic offsets/sizes/strides does not overflow type.
26  if (staticVals.size() != numElements)
27  return op->emitError("expected ") << numElements << " " << name
28  << " values, got " << staticVals.size();
29  unsigned expectedNumDynamicEntries =
30  llvm::count_if(staticVals, ShapedType::isDynamic);
31  if (values.size() != expectedNumDynamicEntries)
32  return op->emitError("expected ")
33  << expectedNumDynamicEntries << " dynamic " << name << " values";
34  return success();
35 }
36 
38  ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
39  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
40  bool generateErrorMessage) {
42  result.isValid = true;
43  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
44  // Nothing to verify for dynamic source dims.
45  if (ShapedType::isDynamic(shape[i]))
46  continue;
47  // Nothing to verify if the offset is dynamic.
48  if (ShapedType::isDynamic(staticOffsets[i]))
49  continue;
50  if (staticOffsets[i] >= shape[i]) {
51  result.errorMessage =
52  std::string("offset ") + std::to_string(i) +
53  " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
54  " >= " + std::to_string(shape[i]);
55  result.isValid = false;
56  return result;
57  }
58  if (ShapedType::isDynamic(staticSizes[i]) ||
59  ShapedType::isDynamic(staticStrides[i]))
60  continue;
61  int64_t lastPos =
62  staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
63  if (lastPos >= shape[i]) {
64  result.errorMessage = std::string("slice along dimension ") +
65  std::to_string(i) +
66  " runs out-of-bounds: " + std::to_string(lastPos) +
67  " >= " + std::to_string(shape[i]);
68  result.isValid = false;
69  return result;
70  }
71  }
72  return result;
73 }
74 
76  ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
77  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
78  bool generateErrorMessage) {
79  auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
80  SmallVector<int64_t> staticValues;
81  for (OpFoldResult ofr : ofrs) {
82  if (auto attr = dyn_cast<Attribute>(ofr)) {
83  staticValues.push_back(cast<IntegerAttr>(attr).getInt());
84  } else {
85  staticValues.push_back(ShapedType::kDynamic);
86  }
87  }
88  return staticValues;
89  };
90  return verifyInBoundsSlice(
91  shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
92  getStaticValues(mixedStrides), generateErrorMessage);
93 }
94 
95 LogicalResult
96 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
97  // A dynamic size is represented as ShapedType::kDynamic in `static_sizes`.
98  // Its corresponding Value appears in `sizes`. Thus, the number of dynamic
99  // dimensions in `static_sizes` must equal the rank of `sizes`.
100  // The same applies to strides and offsets.
101  size_t numDynamicDims =
102  llvm::count_if(op.getStaticSizes(), ShapedType::isDynamic);
103  if (op.getSizes().size() != numDynamicDims) {
104  return op->emitError("expected the number of 'sizes' to match the number "
105  "of dynamic entries in 'static_sizes' (")
106  << op.getSizes().size() << " vs " << numDynamicDims << ")";
107  }
108  size_t numDynamicStrides =
109  llvm::count_if(op.getStaticStrides(), ShapedType::isDynamic);
110  if (op.getStrides().size() != numDynamicStrides) {
111  return op->emitError("expected the number of 'strides' to match the number "
112  "of dynamic entries in 'static_strides' (")
113  << op.getStrides().size() << " vs " << numDynamicStrides << ")";
114  }
115  size_t numDynamicOffsets =
116  llvm::count_if(op.getStaticOffsets(), ShapedType::isDynamic);
117  if (op.getOffsets().size() != numDynamicOffsets) {
118  return op->emitError("expected the number of 'offsets' to match the number "
119  "of dynamic entries in 'static_offsets' (")
120  << op.getOffsets().size() << " vs " << numDynamicOffsets << ")";
121  }
122 
123  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
124  // Offsets can come in 2 flavors:
125  // 1. Either single entry (when maxRanks == 1).
126  // 2. Or as an array whose rank must match that of the mixed sizes.
127  // So that the result type is well-formed.
128  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
129  op.getMixedOffsets().size() != op.getMixedSizes().size())
130  return op->emitError(
131  "expected mixed offsets rank to match mixed sizes rank (")
132  << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
133  << ") so the rank of the result type is well-formed.";
134  // Ranks of mixed sizes and strides must always match so the result type is
135  // well-formed.
136  if (op.getMixedSizes().size() != op.getMixedStrides().size())
137  return op->emitError(
138  "expected mixed sizes rank to match mixed strides rank (")
139  << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
140  << ") so the rank of the result type is well-formed.";
141 
143  op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
144  return failure();
146  op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
147  return failure();
149  op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
150  return failure();
151 
152  for (int64_t offset : op.getStaticOffsets()) {
153  if (offset < 0 && ShapedType::isStatic(offset))
154  return op->emitError("expected offsets to be non-negative, but got ")
155  << offset;
156  }
157  for (int64_t size : op.getStaticSizes()) {
158  if (size < 0 && ShapedType::isStatic(size))
159  return op->emitError("expected sizes to be non-negative, but got ")
160  << size;
161  }
162  return success();
163 }
164 
165 static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
166  switch (delimiter) {
168  return '(';
170  return '<';
172  return '[';
174  return '{';
175  default:
176  llvm_unreachable("unsupported delimiter");
177  }
178 }
179 
180 static char getRightDelimiter(AsmParser::Delimiter delimiter) {
181  switch (delimiter) {
183  return ')';
185  return '>';
187  return ']';
189  return '}';
190  default:
191  llvm_unreachable("unsupported delimiter");
192  }
193 }
194 
196  OperandRange values,
197  ArrayRef<int64_t> integers,
198  ArrayRef<bool> scalableFlags,
199  TypeRange valueTypes,
200  AsmParser::Delimiter delimiter) {
201  char leftDelimiter = getLeftDelimiter(delimiter);
202  char rightDelimiter = getRightDelimiter(delimiter);
203  printer << leftDelimiter;
204  if (integers.empty()) {
205  printer << rightDelimiter;
206  return;
207  }
208 
209  unsigned dynamicValIdx = 0;
210  unsigned scalableIndexIdx = 0;
211  llvm::interleaveComma(integers, printer, [&](int64_t integer) {
212  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
213  printer << "[";
214  if (ShapedType::isDynamic(integer)) {
215  printer << values[dynamicValIdx];
216  if (!valueTypes.empty())
217  printer << " : " << valueTypes[dynamicValIdx];
218  ++dynamicValIdx;
219  } else {
220  printer << integer;
221  }
222  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
223  printer << "]";
224 
225  scalableIndexIdx++;
226  });
227 
228  printer << rightDelimiter;
229 }
230 
232  OpAsmParser &parser,
234  DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
235  SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
236 
237  SmallVector<int64_t, 4> integerVals;
238  SmallVector<bool, 4> scalableVals;
239  auto parseIntegerOrValue = [&]() {
241  auto res = parser.parseOptionalOperand(operand);
242 
243  // When encountering `[`, assume that this is a scalable index.
244  scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
245 
246  if (res.has_value() && succeeded(res.value())) {
247  values.push_back(operand);
248  integerVals.push_back(ShapedType::kDynamic);
249  if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
250  return failure();
251  } else {
252  int64_t integer;
253  if (failed(parser.parseInteger(integer)))
254  return failure();
255  integerVals.push_back(integer);
256  }
257 
258  // If this is assumed to be a scalable index, verify that there's a closing
259  // `]`.
260  if (scalableVals.back() && parser.parseOptionalRSquare().failed())
261  return failure();
262  return success();
263  };
264  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
265  " in dynamic index list"))
266  return parser.emitError(parser.getNameLoc())
267  << "expected SSA value or integer";
268  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
269  scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
270  return success();
271 }
272 
274  OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
276  if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
277  return false;
278  if (a.getStaticSizes().size() != b.getStaticSizes().size())
279  return false;
280  if (a.getStaticStrides().size() != b.getStaticStrides().size())
281  return false;
282  for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
283  if (!cmp(std::get<0>(it), std::get<1>(it)))
284  return false;
285  for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
286  if (!cmp(std::get<0>(it), std::get<1>(it)))
287  return false;
288  for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
289  if (!cmp(std::get<0>(it), std::get<1>(it)))
290  return false;
291  return true;
292 }
293 
295  unsigned idx) {
296  return std::count_if(staticVals.begin(), staticVals.begin() + idx,
297  ShapedType::isDynamic);
298 }
static char getLeftDelimiter(AsmParser::Delimiter delimiter)
static char getRightDelimiter(AsmParser::Delimiter delimiter)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
@ LessGreater
<> brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef< bool > values)
Tensor-typed DenseArrayAttr getters.
Definition: Builders.cpp:146
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
unsigned getNumDynamicEntriesUpToIdx(ArrayRef< int64_t > staticVals, unsigned idx)
Helper method to compute the number of dynamic entries of staticVals, up to idx.
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that `values` has as many elements as the number of entries in `attr` for which `isDynamic` evaluates to true.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.