MLIR  14.0.0git
ViewLikeInterface.cpp
Go to the documentation of this file.
1 //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // ViewLike Interfaces
15 //===----------------------------------------------------------------------===//
16 
17 /// Include the definitions of the view-like interfaces.
18 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19 
21  Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
22  ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
23  /// Check static and dynamic offsets/sizes/strides does not overflow type.
24  if (attr.size() != numElements)
25  return op->emitError("expected ")
26  << numElements << " " << name << " values";
27  unsigned expectedNumDynamicEntries =
28  llvm::count_if(attr.getValue(), [&](Attribute attr) {
29  return isDynamic(attr.cast<IntegerAttr>().getInt());
30  });
31  if (values.size() != expectedNumDynamicEntries)
32  return op->emitError("expected ")
33  << expectedNumDynamicEntries << " dynamic " << name << " values";
34  return success();
35 }
36 
38 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
39  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
40  // Offsets can come in 2 flavors:
41  // 1. Either single entry (when maxRanks == 1).
42  // 2. Or as an array whose rank must match that of the mixed sizes.
43  // So that the result type is well-formed.
44  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) &&
45  op.getMixedOffsets().size() != op.getMixedSizes().size())
46  return op->emitError(
47  "expected mixed offsets rank to match mixed sizes rank (")
48  << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
49  << ") so the rank of the result type is well-formed.";
50  // Ranks of mixed sizes and strides must always match so the result type is
51  // well-formed.
52  if (op.getMixedSizes().size() != op.getMixedStrides().size())
53  return op->emitError(
54  "expected mixed sizes rank to match mixed strides rank (")
55  << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
56  << ") so the rank of the result type is well-formed.";
57 
59  op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
60  ShapedType::isDynamicStrideOrOffset)))
61  return failure();
62  if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
63  op.static_sizes(), op.sizes(),
64  ShapedType::isDynamic)))
65  return failure();
67  op, "stride", maxRanks[2], op.static_strides(), op.strides(),
68  ShapedType::isDynamicStrideOrOffset)))
69  return failure();
70  return success();
71 }
72 
73 template <int64_t dynVal>
75  ArrayAttr arrayAttr) {
76  p << '[';
77  if (arrayAttr.empty()) {
78  p << "]";
79  return;
80  }
81  unsigned idx = 0;
82  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
83  int64_t val = a.cast<IntegerAttr>().getInt();
84  if (val == dynVal)
85  p << values[idx++];
86  else
87  p << val;
88  });
89  p << ']';
90 }
91 
93  Operation *op,
94  OperandRange values,
95  ArrayAttr integers) {
96  return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
97  p, values, integers);
98 }
99 
101  OperandRange values,
102  ArrayAttr integers) {
103  return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
104  integers);
105 }
106 
107 template <int64_t dynVal>
108 static ParseResult
111  ArrayAttr &integers) {
112  if (failed(parser.parseLSquare()))
113  return failure();
114  // 0-D.
115  if (succeeded(parser.parseOptionalRSquare())) {
116  integers = parser.getBuilder().getArrayAttr({});
117  return success();
118  }
119 
120  SmallVector<int64_t, 4> attrVals;
121  while (true) {
122  OpAsmParser::OperandType operand;
123  auto res = parser.parseOptionalOperand(operand);
124  if (res.hasValue() && succeeded(res.getValue())) {
125  values.push_back(operand);
126  attrVals.push_back(dynVal);
127  } else {
128  IntegerAttr attr;
129  if (failed(parser.parseAttribute<IntegerAttr>(attr)))
130  return parser.emitError(parser.getNameLoc())
131  << "expected SSA value or integer";
132  attrVals.push_back(attr.getInt());
133  }
134 
135  if (succeeded(parser.parseOptionalComma()))
136  continue;
137  if (failed(parser.parseRSquare()))
138  return failure();
139  break;
140  }
141  integers = parser.getBuilder().getI64ArrayAttr(attrVals);
142  return success();
143 }
144 
147  ArrayAttr &integers) {
148  return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
149  parser, values, integers);
150 }
151 
154  ArrayAttr &integers) {
155  return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
156  integers);
157 }
158 
160  OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
162  if (a.static_offsets().size() != b.static_offsets().size())
163  return false;
164  if (a.static_sizes().size() != b.static_sizes().size())
165  return false;
166  if (a.static_strides().size() != b.static_strides().size())
167  return false;
168  for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
169  if (!cmp(std::get<0>(it), std::get<1>(it)))
170  return false;
171  for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
172  if (!cmp(std::get<0>(it), std::get<1>(it)))
173  return false;
174  for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
175  if (!cmp(std::get<0>(it), std::get<1>(it)))
176  return false;
177  return true;
178 }
179 
180 void OffsetSizeAndStrideOpInterface::expandToRank(
181  Value target, SmallVector<OpFoldResult> &offsets,
183  llvm::function_ref<OpFoldResult(Value, int64_t)> createOrFoldDim) {
184  auto shapedType = target.getType().cast<ShapedType>();
185  unsigned rank = shapedType.getRank();
186  assert(offsets.size() == sizes.size() && "mismatched lengths");
187  assert(offsets.size() == strides.size() && "mismatched lengths");
188  assert(offsets.size() <= rank && "rank overflow");
189  MLIRContext *ctx = target.getContext();
190  Attribute zero = IntegerAttr::get(IndexType::get(ctx), APInt(64, 0));
191  Attribute one = IntegerAttr::get(IndexType::get(ctx), APInt(64, 1));
192  for (unsigned i = offsets.size(); i < rank; ++i) {
193  offsets.push_back(zero);
194  sizes.push_back(createOrFoldDim(target, i));
195  strides.push_back(one);
196  }
197 }
198 
200 mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
201  ArrayAttr staticOffsets, ValueRange offsets) {
203  unsigned numDynamic = 0;
204  unsigned count = static_cast<unsigned>(staticOffsets.size());
205  for (unsigned idx = 0; idx < count; ++idx) {
206  if (op.isDynamicOffset(idx))
207  res.push_back(offsets[numDynamic++]);
208  else
209  res.push_back(staticOffsets[idx]);
210  }
211  return res;
212 }
213 
215 mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
216  ValueRange sizes) {
218  unsigned numDynamic = 0;
219  unsigned count = static_cast<unsigned>(staticSizes.size());
220  for (unsigned idx = 0; idx < count; ++idx) {
221  if (op.isDynamicSize(idx))
222  res.push_back(sizes[numDynamic++]);
223  else
224  res.push_back(staticSizes[idx]);
225  }
226  return res;
227 }
228 
230 mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
231  ArrayAttr staticStrides, ValueRange strides) {
233  unsigned numDynamic = 0;
234  unsigned count = static_cast<unsigned>(staticStrides.size());
235  for (unsigned idx = 0; idx < count; ++idx) {
236  if (op.isDynamicStride(idx))
237  res.push_back(strides[numDynamic++]);
238  else
239  res.push_back(staticStrides[idx]);
240  }
241  return res;
242 }
This is the representation of an operand reference.
Include the generated interface declarations.
SmallVector< OpFoldResult, 4 > getMixedOffsets(OffsetSizeAndStrideOpInterface op, ArrayAttr staticOffsets, ValueRange offsets)
Return a vector of all the static or dynamic offsets of the op from provided external static and dyna...
U cast() const
Definition: Attributes.h:123
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayAttr integers)
Printer hook for custom directive in assemblyFormat.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:214
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr, ValueRange values, llvm::function_ref< bool(int64_t)> isDynamic)
Verify that the values have as many elements as the number of entries in attr for which isDynamic evaluates to true.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayAttr integers)
Printer hook for custom directive in assemblyFormat.
virtual ParseResult parseLSquare()=0
Parse a [ token.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< OpFoldResult, 4 > getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes, ValueRange sizes)
Return a vector of all the static or dynamic sizes of the op from provided external static and dynami...
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual OptionalParseResult parseOptionalOperand(OperandType &result)=0
Parse a single operand if present.
virtual ParseResult parseRSquare()=0
Parse a ] token.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
Type getType() const
Return the type of this value.
Definition: Value.h:117
ParseResult parseOperandsOrIntegersSizesList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::OperandType > &values, ArrayAttr &integers)
Parser hook for custom directive in assemblyFormat.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class implements the operand iterators for the Operation class.
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op)
static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr)
static ParseResult parseOperandsOrIntegersImpl(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::OperandType > &values, ArrayAttr &integers)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
This class provides an abstraction over the different types of ranges over Values.
SmallVector< OpFoldResult, 4 > getMixedStrides(OffsetSizeAndStrideOpInterface op, ArrayAttr staticStrides, ValueRange strides)
Return a vector of all the static or dynamic strides of the op from provided external static and dyna...
ParseResult parseOperandsOrIntegersOffsetsOrStridesList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::OperandType > &values, ArrayAttr &integers)
Parser hook for custom directive in assemblyFormat.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:120
bool sameOffsetsSizesAndStrides(OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref< bool(OpFoldResult, OpFoldResult)> cmp)
U cast() const
Definition: Types.h:250