MLIR  17.0.0git
ViewLikeInterface.cpp
Go to the documentation of this file.
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
17 /// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
21  StringRef name,
22  unsigned numElements,
23  ArrayRef<int64_t> staticVals,
24  ValueRange values) {
25  // Check static and dynamic offsets/sizes/strides does not overflow type.
26  if (staticVals.size() != numElements)
27  return op->emitError("expected ") << numElements << " " << name
28  << " values, got " << staticVals.size();
29  unsigned expectedNumDynamicEntries =
30  llvm::count_if(staticVals, [&](int64_t staticVal) {
31  return ShapedType::isDynamic(staticVal);
32  });
33  if (values.size() != expectedNumDynamicEntries)
34  return op->emitError("expected ")
35  << expectedNumDynamicEntries << " dynamic " << name << " values";
36  return success();
37 }
38 
40 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
41  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
42  // Offsets can come in 2 flavors:
43  // 1. Either single entry (when maxRanks == 1).
44  // 2. Or as an array whose rank must match that of the mixed sizes.
45  // So that the result type is well-formed.
46  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
47  op.getMixedOffsets().size() != op.getMixedSizes().size())
48  return op->emitError(
49  "expected mixed offsets rank to match mixed sizes rank (")
50  << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
51  << ") so the rank of the result type is well-formed.";
52  // Ranks of mixed sizes and strides must always match so the result type is
53  // well-formed.
54  if (op.getMixedSizes().size() != op.getMixedStrides().size())
55  return op->emitError(
56  "expected mixed sizes rank to match mixed strides rank (")
57  << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
58  << ") so the rank of the result type is well-formed.";
59 
60  if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0],
61  op.static_offsets(), op.offsets())))
62  return failure();
63  if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
64  op.static_sizes(), op.sizes())))
65  return failure();
66  if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2],
67  op.static_strides(), op.strides())))
68  return failure();
69  return success();
70 }
71 
72 static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
73  switch (delimiter) {
75  return '(';
77  return '<';
79  return '[';
81  return '{';
82  default:
83  llvm_unreachable("unsupported delimiter");
84  }
85 }
86 
87 static char getRightDelimiter(AsmParser::Delimiter delimiter) {
88  switch (delimiter) {
90  return ')';
92  return '>';
94  return ']';
96  return '}';
97  default:
98  llvm_unreachable("unsupported delimiter");
99  }
100 }
101 
103  OperandRange values,
104  ArrayRef<int64_t> integers,
105  TypeRange valueTypes,
106  AsmParser::Delimiter delimiter,
107  bool isTrailingIdxScalable) {
108  char leftDelimiter = getLeftDelimiter(delimiter);
109  char rightDelimiter = getRightDelimiter(delimiter);
110  printer << leftDelimiter;
111  if (integers.empty()) {
112  printer << rightDelimiter;
113  return;
114  }
115 
116  int64_t trailingScalableInteger;
117  if (isTrailingIdxScalable) {
118  // ATM only the trailing idx can be scalable
119  trailingScalableInteger = integers.back();
120  integers = integers.drop_back();
121  }
122 
123  unsigned idx = 0;
124  llvm::interleaveComma(integers, printer, [&](int64_t integer) {
125  if (ShapedType::isDynamic(integer)) {
126  printer << values[idx];
127  if (!valueTypes.empty())
128  printer << " : " << valueTypes[idx];
129  ++idx;
130  } else {
131  printer << integer;
132  }
133  });
134 
135  // Print the trailing scalable index
136  if (isTrailingIdxScalable) {
137  printer << ", ";
138  printer << "[";
139  printer << trailingScalableInteger;
140  printer << "]";
141  }
142 
143  printer << rightDelimiter;
144 }
145 
147  OpAsmParser &parser,
149  DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable,
150  SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
151 
152  SmallVector<int64_t, 4> integerVals;
153  bool foundScalable = false;
154  auto parseIntegerOrValue = [&]() {
156  auto res = parser.parseOptionalOperand(operand);
157 
158  // If `foundScalable` has already been set to `true` then a non-trailing
159  // tile size was identified as scalable.
160  if (foundScalable) {
161  parser.emitError(parser.getNameLoc())
162  << "non-trailing tile size cannot be scalable";
163  return failure();
164  }
165 
166  if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded())
167  foundScalable = true;
168 
169  if (res.has_value() && succeeded(res.value())) {
170  values.push_back(operand);
171  integerVals.push_back(ShapedType::kDynamic);
172  if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
173  return failure();
174  } else {
175  int64_t integer;
176  if (failed(parser.parseInteger(integer)))
177  return failure();
178  integerVals.push_back(integer);
179  }
180  if (foundScalable && parser.parseOptionalRSquare().failed())
181  return failure();
182  return success();
183  };
184  if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
185  " in dynamic index list"))
186  return parser.emitError(parser.getNameLoc())
187  << "expected SSA value or integer";
188  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
189  if (isTrailingIdxScalable)
190  *isTrailingIdxScalable = foundScalable;
191  return success();
192 }
193 
195  OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
197  if (a.static_offsets().size() != b.static_offsets().size())
198  return false;
199  if (a.static_sizes().size() != b.static_sizes().size())
200  return false;
201  if (a.static_strides().size() != b.static_strides().size())
202  return false;
203  for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
204  if (!cmp(std::get<0>(it), std::get<1>(it)))
205  return false;
206  for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
207  if (!cmp(std::get<0>(it), std::get<1>(it)))
208  return false;
209  for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
210  if (!cmp(std::get<0>(it), std::get<1>(it)))
211  return false;
212  return true;
213 }
static char getLeftDelimiter(AsmParser::Delimiter delimiter)
static char getRightDelimiter(AsmParser::Delimiter delimiter)
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Paren
Parens surrounding zero or more operands.
@ Braces
{} brackets surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
@ LessGreater
<> brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:170
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:265
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:266
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
This header declares functions that assist transformations in the MemRef dialect.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable=nullptr, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square, bool isTrailingIdxScalable=false)
Printer hook for custom directive in assemblyFormat.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that `values` has as many elements as the number of entries in attr for which isDynamic ev...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:44
This is the representation of an operand reference.