MLIR  21.0.0git
ViewLikeInterface.cpp
Go to the documentation of this file.
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
17 /// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
21  StringRef name,
22  unsigned numElements,
23  ArrayRef<int64_t> staticVals,
24  ValueRange values) {
    // Checks one offsets/sizes/strides-style list for internal consistency:
    //   1. `staticVals` must contain exactly `numElements` entries.
    //   2. `values` must contain exactly one entry per entry of `staticVals`
    //      for which ShapedType::isDynamic() is true.
    // On a mismatch an error mentioning `name` is emitted on `op` and
    // returned; otherwise success() is returned.
25  // Check that static and dynamic offsets/sizes/strides do not overflow type.
26  if (staticVals.size() != numElements)
27  return op->emitError("expected ") << numElements << " " << name
28  << " values, got " << staticVals.size();
    // Number of placeholders in the static list that need a matching value.
29  unsigned expectedNumDynamicEntries =
30  llvm::count_if(staticVals, [](int64_t staticVal) {
31  return ShapedType::isDynamic(staticVal);
32  });
33  if (values.size() != expectedNumDynamicEntries)
34  return op->emitError("expected ")
35  << expectedNumDynamicEntries << " dynamic " << name << " values";
36  return success();
37 }
38 
40  ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
41  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
42  bool generateErrorMessage) {
    // Checks, dimension by dimension, that a static offset/size/stride access
    // stays inside `shape`. The walk stops at the first violation, setting
    // `result.isValid` to false and filling `result.errorMessage`; if no
    // violation is found the result is returned with `isValid == true`.
    //
    // NOTE(review): `generateErrorMessage` is not referenced in any visible
    // line of this body; the message strings below are built whenever a
    // violation is found. Confirm whether the flag is meant to gate that.
    // NOTE(review): `staticOffsets`, `staticSizes` and `staticStrides` are
    // indexed up to `shape.size()` with no visible length check; presumably
    // callers guarantee matching ranks. Verify against callers.
44  result.isValid = true;
45  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
46  // Nothing to verify for dynamic source dims.
47  if (ShapedType::isDynamic(shape[i]))
48  continue;
49  // Nothing to verify if the offset is dynamic.
50  if (ShapedType::isDynamic(staticOffsets[i]))
51  continue;
    // A static offset must address an existing element of a static dim.
52  if (staticOffsets[i] >= shape[i]) {
53  result.errorMessage =
54  std::string("offset ") + std::to_string(i) +
55  " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
56  " >= " + std::to_string(shape[i]);
57  result.isValid = false;
58  return result;
59  }
    // The extent check below needs both the size and the stride to be static.
60  if (ShapedType::isDynamic(staticSizes[i]) ||
61  ShapedType::isDynamic(staticStrides[i]))
62  continue;
    // Index of the last element touched along this dimension:
    // offset + (size - 1) * stride.
63  int64_t lastPos =
64  staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
65  if (lastPos >= shape[i]) {
66  result.errorMessage = std::string("slice along dimension ") +
67  std::to_string(i) +
68  " runs out-of-bounds: " + std::to_string(lastPos) +
69  " >= " + std::to_string(shape[i]);
70  result.isValid = false;
71  return result;
72  }
73  }
74  return result;
75 }
76 
78  ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
79  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
80  bool generateErrorMessage) {
    // Overload taking mixed (attribute-or-value) operands: each list is
    // lowered to a list of int64_t and the check is forwarded to the
    // static-value overload of verifyInBoundsSlice.
    //
    // Lowering rule: an entry holding an Attribute contributes its integer
    // value (the attribute is cast to IntegerAttr); any other entry
    // contributes ShapedType::kDynamic.
81  auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
82  SmallVector<int64_t> staticValues;
83  for (OpFoldResult ofr : ofrs) {
84  if (auto attr = dyn_cast<Attribute>(ofr)) {
85  staticValues.push_back(cast<IntegerAttr>(attr).getInt());
86  } else {
87  staticValues.push_back(ShapedType::kDynamic);
88  }
89  }
90  return staticValues;
91  };
92  return verifyInBoundsSlice(
93  shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
94  getStaticValues(mixedStrides), generateErrorMessage);
95 }
96 
/// Verifies the offset/size/stride operands of `op`.
///
/// Checks performed by the visible code, in order:
///   - the number of mixed offsets equals the number of mixed sizes, unless
///     there is a single offset and the maximum offset rank is 1;
///   - the number of mixed sizes equals the number of mixed strides;
///   - per-list checks for "offset", "size" and "stride" against
///     `maxRanks[0..2]` (the heads of those three calls are not visible in
///     this listing);
///   - no static offset or static size is negative, ignoring entries for
///     which ShapedType::isDynamic() is true.
/// Returns the emitted error / failure() for the first check that fails and
/// success() otherwise. Static strides get no sign check in this function.
97 LogicalResult
98 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
99  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
100  // Offsets can come in 2 flavors:
101  // 1. Either single entry (when maxRanks == 1).
102  // 2. Or as an array whose rank must match that of the mixed sizes.
103  // So that the result type is well-formed.
104  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
105  op.getMixedOffsets().size() != op.getMixedSizes().size())
106  return op->emitError(
107  "expected mixed offsets rank to match mixed sizes rank (")
108  << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
109  << ") so the rank of the result type is well-formed.";
110  // Ranks of mixed sizes and strides must always match so the result type is
111  // well-formed.
112  if (op.getMixedSizes().size() != op.getMixedStrides().size())
113  return op->emitError(
114  "expected mixed sizes rank to match mixed strides rank (")
115  << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
116  << ") so the rank of the result type is well-formed.";
117 
    // Static/dynamic bookkeeping for each of the three lists; any failure
    // aborts verification.
119  op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
120  return failure();
122  op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
123  return failure();
125  op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
126  return failure();
127 
    // Entries flagged by ShapedType::isDynamic() are exempt from the sign
    // checks below.
128  for (int64_t offset : op.getStaticOffsets()) {
129  if (offset < 0 && !ShapedType::isDynamic(offset))
130  return op->emitError("expected offsets to be non-negative, but got ")
131  << offset;
132  }
133  for (int64_t size : op.getStaticSizes()) {
134  if (size < 0 && !ShapedType::isDynamic(size))
135  return op->emitError("expected sizes to be non-negative, but got ")
136  << size;
137  }
138  return success();
139 }
140 
/// Returns the opening character for `delimiter`: one of '(', '<', '[' or
/// '{'. Any delimiter kind without a case here hits llvm_unreachable.
/// (The `case` labels themselves are not visible in this listing.)
141 static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
142  switch (delimiter) {
144  return '(';
146  return '<';
148  return '[';
150  return '{';
151  default:
152  llvm_unreachable("unsupported delimiter");
153  }
154 }
155 
/// Returns the closing character for `delimiter`: one of ')', '>', ']' or
/// '}'. Any delimiter kind without a case here hits llvm_unreachable.
/// (The `case` labels themselves are not visible in this listing.)
156 static char getRightDelimiter(AsmParser::Delimiter delimiter) {
157  switch (delimiter) {
159  return ')';
161  return '>';
163  return ']';
165  return '}';
166  default:
167  llvm_unreachable("unsupported delimiter");
168  }
169 }
170 
172  OperandRange values,
173  ArrayRef<int64_t> integers,
174  ArrayRef<bool> scalableFlags,
175  TypeRange valueTypes,
176  AsmParser::Delimiter delimiter) {
    // Prints a comma-separated list enclosed in the characters selected by
    // `delimiter`. There is one printed element per entry of `integers`:
    //   - an entry for which ShapedType::isDynamic() is true prints the next
    //     unused element of `values`, followed by " : <type>" when
    //     `valueTypes` is non-empty;
    //   - any other entry prints the integer itself.
    // When `scalableFlags` is non-empty, an element whose flag is set is
    // additionally wrapped in '[' and ']'.
177  char leftDelimiter = getLeftDelimiter(delimiter);
178  char rightDelimiter = getRightDelimiter(delimiter);
179  printer << leftDelimiter;
    // Empty list: emit just the two delimiters.
180  if (integers.empty()) {
181  printer << rightDelimiter;
182  return;
183  }
184 
    // `dynamicValIdx` walks `values` (and `valueTypes`) and advances only on
    // dynamic entries; `scalableIndexIdx` walks `scalableFlags` and advances
    // on every element.
185  unsigned dynamicValIdx = 0;
186  unsigned scalableIndexIdx = 0;
187  llvm::interleaveComma(integers, printer, [&](int64_t integer) {
188  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
189  printer << "[";
190  if (ShapedType::isDynamic(integer)) {
191  printer << values[dynamicValIdx];
192  if (!valueTypes.empty())
193  printer << " : " << valueTypes[dynamicValIdx];
194  ++dynamicValIdx;
195  } else {
196  printer << integer;
197  }
198  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
199  printer << "]";
200 
201  scalableIndexIdx++;
202  });
203 
204  printer << rightDelimiter;
205 }
206 
208  OpAsmParser &parser,
210  DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
211  SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
212 
    // Parses a `delimiter`-enclosed, comma-separated list whose elements are
    // either SSA operands or integer literals. On success:
    //   - `values` has received one entry per parsed operand;
    //   - `integers` holds one entry per element: the literal, or
    //     ShapedType::kDynamic for an operand;
    //   - `scalableFlags` holds one bool per element, recording whether a
    //     '[' was consumed for it;
    //   - `valueTypes`, when non-null, has received one type per operand.
213  SmallVector<int64_t, 4> integerVals;
214  SmallVector<bool, 4> scalableVals;
    // Element parser invoked once per list element.
215  auto parseIntegerOrValue = [&]() {
    // The optional operand is attempted first, before any '[' is consumed.
217  auto res = parser.parseOptionalOperand(operand);
218 
219  // When encountering `[`, assume that this is a scalable index.
220  scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
221 
222  if (res.has_value() && succeeded(res.value())) {
    // Operand element: record a dynamic placeholder and, if types were
    // requested, parse the trailing `: type`.
223  values.push_back(operand);
224  integerVals.push_back(ShapedType::kDynamic);
225  if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
226  return failure();
227  } else {
    // Otherwise the element must be an integer literal.
228  int64_t integer;
229  if (failed(parser.parseInteger(integer)))
230  return failure();
231  integerVals.push_back(integer);
232  }
233 
234  // If this is assumed to be a scalable index, verify that there's a closing
235  // `]`.
236  if (scalableVals.back() && parser.parseOptionalRSquare().failed())
237  return failure();
238  return success();
239  };
    // Any element failure is reported at the name location with a generic
    // message.
240  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
241  " in dynamic index list"))
242  return parser.emitError(parser.getNameLoc())
243  << "expected SSA value or integer";
    // Materialize the collected vectors as attributes for the caller.
244  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
245  scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
246  return success();
247 }
248 
250  OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
    // Returns true when `a` and `b` have the same number of static offsets,
    // sizes and strides, and `cmp` accepts every corresponding pair of mixed
    // offsets, mixed sizes and mixed strides. Returns false at the first
    // difference found.
    // Cheap length comparison first, before any element is compared.
252  if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
253  return false;
254  if (a.getStaticSizes().size() != b.getStaticSizes().size())
255  return false;
256  if (a.getStaticStrides().size() != b.getStaticStrides().size())
257  return false;
    // Element-wise comparison, delegated to the caller-supplied `cmp`.
258  for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
259  if (!cmp(std::get<0>(it), std::get<1>(it)))
260  return false;
261  for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
262  if (!cmp(std::get<0>(it), std::get<1>(it)))
263  return false;
264  for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
265  if (!cmp(std::get<0>(it), std::get<1>(it)))
266  return false;
267  return true;
268 }
269 
271  unsigned idx) {
    // Counts the entries among the first `idx` elements of `staticVals`
    // (positions [0, idx)) for which ShapedType::isDynamic() is true.
    // NOTE(review): no visible check that `idx <= staticVals.size()`;
    // presumably guaranteed by callers. Verify.
272  return std::count_if(staticVals.begin(), staticVals.begin() + idx,
273  [&](int64_t val) { return ShapedType::isDynamic(val); });
274 }
static char getLeftDelimiter(AsmParser::Delimiter delimiter)
static char getRightDelimiter(AsmParser::Delimiter delimiter)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
@ LessGreater
<> brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef< bool > values)
Tensor-typed DenseArrayAttr getters.
Definition: Builders.cpp:147
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
unsigned getNumDynamicEntriesUpToIdx(ArrayRef< int64_t > staticVals, unsigned idx)
Helper method to compute the number of dynamic entries of staticVals, up to idx.
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
Include the generated interface declarations.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.