MLIR  21.0.0git
ViewLikeInterface.cpp
Go to the documentation of this file.
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
17 /// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
21  StringRef name,
22  unsigned numElements,
23  ArrayRef<int64_t> staticVals,
24  ValueRange values) {
25  // Check static and dynamic offsets/sizes/strides does not overflow type.
26  if (staticVals.size() != numElements)
27  return op->emitError("expected ") << numElements << " " << name
28  << " values, got " << staticVals.size();
29  unsigned expectedNumDynamicEntries =
30  llvm::count_if(staticVals, ShapedType::isDynamic);
31  if (values.size() != expectedNumDynamicEntries)
32  return op->emitError("expected ")
33  << expectedNumDynamicEntries << " dynamic " << name << " values";
34  return success();
35 }
36 
38  ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
39  ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
40  bool generateErrorMessage) {
42  result.isValid = true;
43  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
44  // Nothing to verify for dynamic source dims.
45  if (ShapedType::isDynamic(shape[i]))
46  continue;
47  // Nothing to verify if the offset is dynamic.
48  if (ShapedType::isDynamic(staticOffsets[i]))
49  continue;
50  if (staticOffsets[i] >= shape[i]) {
51  result.errorMessage =
52  std::string("offset ") + std::to_string(i) +
53  " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
54  " >= " + std::to_string(shape[i]);
55  result.isValid = false;
56  return result;
57  }
58  if (ShapedType::isDynamic(staticSizes[i]) ||
59  ShapedType::isDynamic(staticStrides[i]))
60  continue;
61  int64_t lastPos =
62  staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
63  if (lastPos >= shape[i]) {
64  result.errorMessage = std::string("slice along dimension ") +
65  std::to_string(i) +
66  " runs out-of-bounds: " + std::to_string(lastPos) +
67  " >= " + std::to_string(shape[i]);
68  result.isValid = false;
69  return result;
70  }
71  }
72  return result;
73 }
74 
76  ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
77  ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
78  bool generateErrorMessage) {
79  auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
80  SmallVector<int64_t> staticValues;
81  for (OpFoldResult ofr : ofrs) {
82  if (auto attr = dyn_cast<Attribute>(ofr)) {
83  staticValues.push_back(cast<IntegerAttr>(attr).getInt());
84  } else {
85  staticValues.push_back(ShapedType::kDynamic);
86  }
87  }
88  return staticValues;
89  };
90  return verifyInBoundsSlice(
91  shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
92  getStaticValues(mixedStrides), generateErrorMessage);
93 }
94 
95 LogicalResult
96 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
97  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
98  // Offsets can come in 2 flavors:
99  // 1. Either single entry (when maxRanks == 1).
100  // 2. Or as an array whose rank must match that of the mixed sizes.
101  // So that the result type is well-formed.
102  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
103  op.getMixedOffsets().size() != op.getMixedSizes().size())
104  return op->emitError(
105  "expected mixed offsets rank to match mixed sizes rank (")
106  << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
107  << ") so the rank of the result type is well-formed.";
108  // Ranks of mixed sizes and strides must always match so the result type is
109  // well-formed.
110  if (op.getMixedSizes().size() != op.getMixedStrides().size())
111  return op->emitError(
112  "expected mixed sizes rank to match mixed strides rank (")
113  << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
114  << ") so the rank of the result type is well-formed.";
115 
117  op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
118  return failure();
120  op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
121  return failure();
123  op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
124  return failure();
125 
126  for (int64_t offset : op.getStaticOffsets()) {
127  if (offset < 0 && !ShapedType::isDynamic(offset))
128  return op->emitError("expected offsets to be non-negative, but got ")
129  << offset;
130  }
131  for (int64_t size : op.getStaticSizes()) {
132  if (size < 0 && !ShapedType::isDynamic(size))
133  return op->emitError("expected sizes to be non-negative, but got ")
134  << size;
135  }
136  return success();
137 }
138 
139 static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
140  switch (delimiter) {
142  return '(';
144  return '<';
146  return '[';
148  return '{';
149  default:
150  llvm_unreachable("unsupported delimiter");
151  }
152 }
153 
154 static char getRightDelimiter(AsmParser::Delimiter delimiter) {
155  switch (delimiter) {
157  return ')';
159  return '>';
161  return ']';
163  return '}';
164  default:
165  llvm_unreachable("unsupported delimiter");
166  }
167 }
168 
170  OperandRange values,
171  ArrayRef<int64_t> integers,
172  ArrayRef<bool> scalableFlags,
173  TypeRange valueTypes,
174  AsmParser::Delimiter delimiter) {
175  char leftDelimiter = getLeftDelimiter(delimiter);
176  char rightDelimiter = getRightDelimiter(delimiter);
177  printer << leftDelimiter;
178  if (integers.empty()) {
179  printer << rightDelimiter;
180  return;
181  }
182 
183  unsigned dynamicValIdx = 0;
184  unsigned scalableIndexIdx = 0;
185  llvm::interleaveComma(integers, printer, [&](int64_t integer) {
186  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
187  printer << "[";
188  if (ShapedType::isDynamic(integer)) {
189  printer << values[dynamicValIdx];
190  if (!valueTypes.empty())
191  printer << " : " << valueTypes[dynamicValIdx];
192  ++dynamicValIdx;
193  } else {
194  printer << integer;
195  }
196  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
197  printer << "]";
198 
199  scalableIndexIdx++;
200  });
201 
202  printer << rightDelimiter;
203 }
204 
206  OpAsmParser &parser,
208  DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
209  SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
210 
211  SmallVector<int64_t, 4> integerVals;
212  SmallVector<bool, 4> scalableVals;
213  auto parseIntegerOrValue = [&]() {
215  auto res = parser.parseOptionalOperand(operand);
216 
217  // When encountering `[`, assume that this is a scalable index.
218  scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
219 
220  if (res.has_value() && succeeded(res.value())) {
221  values.push_back(operand);
222  integerVals.push_back(ShapedType::kDynamic);
223  if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
224  return failure();
225  } else {
226  int64_t integer;
227  if (failed(parser.parseInteger(integer)))
228  return failure();
229  integerVals.push_back(integer);
230  }
231 
232  // If this is assumed to be a scalable index, verify that there's a closing
233  // `]`.
234  if (scalableVals.back() && parser.parseOptionalRSquare().failed())
235  return failure();
236  return success();
237  };
238  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
239  " in dynamic index list"))
240  return parser.emitError(parser.getNameLoc())
241  << "expected SSA value or integer";
242  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
243  scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
244  return success();
245 }
246 
248  OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
250  if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
251  return false;
252  if (a.getStaticSizes().size() != b.getStaticSizes().size())
253  return false;
254  if (a.getStaticStrides().size() != b.getStaticStrides().size())
255  return false;
256  for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
257  if (!cmp(std::get<0>(it), std::get<1>(it)))
258  return false;
259  for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
260  if (!cmp(std::get<0>(it), std::get<1>(it)))
261  return false;
262  for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
263  if (!cmp(std::get<0>(it), std::get<1>(it)))
264  return false;
265  return true;
266 }
267 
269  unsigned idx) {
270  return std::count_if(staticVals.begin(), staticVals.begin() + idx,
271  ShapedType::isDynamic);
272 }
static char getLeftDelimiter(AsmParser::Delimiter delimiter)
static char getRightDelimiter(AsmParser::Delimiter delimiter)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
@ LessGreater
<> brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:165
DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef< bool > values)
Tensor-typed DenseArrayAttr getters.
Definition: Builders.cpp:149
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
unsigned getNumDynamicEntriesUpToIdx(ArrayRef< int64_t > staticVals, unsigned idx)
Helper method to compute the number of dynamic entries among the first idx elements of staticVals (i.e. entries before, not including, position idx).
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
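Verify that an OffsetSizeAndStrideOpInterface op has consistent offset/size/stride ranks, the expected number of dynamic values for each list, and non-negative static offsets and sizes.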
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that attr has expectedNumElements entries and that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
Result for slice bounds verification.
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.