MLIR  21.0.0git
ViewLikeInterface.cpp
Go to the documentation of this file.
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
17 /// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
21  StringRef name,
22  unsigned numElements,
23  ArrayRef<int64_t> staticVals,
24  ValueRange values) {
25  // Check static and dynamic offsets/sizes/strides does not overflow type.
26  if (staticVals.size() != numElements)
27  return op->emitError("expected ") << numElements << " " << name
28  << " values, got " << staticVals.size();
29  unsigned expectedNumDynamicEntries =
30  llvm::count_if(staticVals, [](int64_t staticVal) {
31  return ShapedType::isDynamic(staticVal);
32  });
33  if (values.size() != expectedNumDynamicEntries)
34  return op->emitError("expected ")
35  << expectedNumDynamicEntries << " dynamic " << name << " values";
36  return success();
37 }
38 
39 LogicalResult
40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
41  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
42  // Offsets can come in 2 flavors:
43  // 1. Either single entry (when maxRanks == 1).
44  // 2. Or as an array whose rank must match that of the mixed sizes.
45  // So that the result type is well-formed.
46  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
47  op.getMixedOffsets().size() != op.getMixedSizes().size())
48  return op->emitError(
49  "expected mixed offsets rank to match mixed sizes rank (")
50  << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
51  << ") so the rank of the result type is well-formed.";
52  // Ranks of mixed sizes and strides must always match so the result type is
53  // well-formed.
54  if (op.getMixedSizes().size() != op.getMixedStrides().size())
55  return op->emitError(
56  "expected mixed sizes rank to match mixed strides rank (")
57  << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
58  << ") so the rank of the result type is well-formed.";
59 
61  op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
62  return failure();
64  op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
65  return failure();
67  op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
68  return failure();
69 
70  for (int64_t offset : op.getStaticOffsets()) {
71  if (offset < 0 && !ShapedType::isDynamic(offset))
72  return op->emitError("expected offsets to be non-negative, but got ")
73  << offset;
74  }
75  for (int64_t size : op.getStaticSizes()) {
76  if (size < 0 && !ShapedType::isDynamic(size))
77  return op->emitError("expected sizes to be non-negative, but got ")
78  << size;
79  }
80  return success();
81 }
82 
83 static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
84  switch (delimiter) {
86  return '(';
88  return '<';
90  return '[';
92  return '{';
93  default:
94  llvm_unreachable("unsupported delimiter");
95  }
96 }
97 
98 static char getRightDelimiter(AsmParser::Delimiter delimiter) {
99  switch (delimiter) {
101  return ')';
103  return '>';
105  return ']';
107  return '}';
108  default:
109  llvm_unreachable("unsupported delimiter");
110  }
111 }
112 
114  OperandRange values,
115  ArrayRef<int64_t> integers,
116  ArrayRef<bool> scalableFlags,
117  TypeRange valueTypes,
118  AsmParser::Delimiter delimiter) {
119  char leftDelimiter = getLeftDelimiter(delimiter);
120  char rightDelimiter = getRightDelimiter(delimiter);
121  printer << leftDelimiter;
122  if (integers.empty()) {
123  printer << rightDelimiter;
124  return;
125  }
126 
127  unsigned dynamicValIdx = 0;
128  unsigned scalableIndexIdx = 0;
129  llvm::interleaveComma(integers, printer, [&](int64_t integer) {
130  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
131  printer << "[";
132  if (ShapedType::isDynamic(integer)) {
133  printer << values[dynamicValIdx];
134  if (!valueTypes.empty())
135  printer << " : " << valueTypes[dynamicValIdx];
136  ++dynamicValIdx;
137  } else {
138  printer << integer;
139  }
140  if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
141  printer << "]";
142 
143  scalableIndexIdx++;
144  });
145 
146  printer << rightDelimiter;
147 }
148 
150  OpAsmParser &parser,
152  DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
153  SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
154 
155  SmallVector<int64_t, 4> integerVals;
156  SmallVector<bool, 4> scalableVals;
157  auto parseIntegerOrValue = [&]() {
159  auto res = parser.parseOptionalOperand(operand);
160 
161  // When encountering `[`, assume that this is a scalable index.
162  scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
163 
164  if (res.has_value() && succeeded(res.value())) {
165  values.push_back(operand);
166  integerVals.push_back(ShapedType::kDynamic);
167  if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
168  return failure();
169  } else {
170  int64_t integer;
171  if (failed(parser.parseInteger(integer)))
172  return failure();
173  integerVals.push_back(integer);
174  }
175 
176  // If this is assumed to be a scalable index, verify that there's a closing
177  // `]`.
178  if (scalableVals.back() && parser.parseOptionalRSquare().failed())
179  return failure();
180  return success();
181  };
182  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
183  " in dynamic index list"))
184  return parser.emitError(parser.getNameLoc())
185  << "expected SSA value or integer";
186  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
187  scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
188  return success();
189 }
190 
192  OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
194  if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
195  return false;
196  if (a.getStaticSizes().size() != b.getStaticSizes().size())
197  return false;
198  if (a.getStaticStrides().size() != b.getStaticStrides().size())
199  return false;
200  for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
201  if (!cmp(std::get<0>(it), std::get<1>(it)))
202  return false;
203  for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
204  if (!cmp(std::get<0>(it), std::get<1>(it)))
205  return false;
206  for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
207  if (!cmp(std::get<0>(it), std::get<1>(it)))
208  return false;
209  return true;
210 }
211 
213  unsigned idx) {
214  return std::count_if(staticVals.begin(), staticVals.begin() + idx,
215  [&](int64_t val) { return ShapedType::isDynamic(val); });
216 }
static char getLeftDelimiter(AsmParser::Delimiter delimiter)
static char getRightDelimiter(AsmParser::Delimiter delimiter)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
@ LessGreater
<> brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef< bool > values)
Tensor-typed DenseArrayAttr getters.
Definition: Builders.cpp:147
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
unsigned getNumDynamicEntriesUpToIdx(ArrayRef< int64_t > staticVals, unsigned idx)
Helper method to compute the number of dynamic entries of staticVals, up to idx.
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
Include the generated interface declarations.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that `values` has as many elements as the number of entries in attr for which isDynamic ev...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.