RuntimeOpVerification.cpp
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

using namespace mlir;

/// Generate an error message string for the given op and the specified error.
static std::string generateErrorMessage(Operation *op, const std::string &msg) {
  std::string buffer;
  llvm::raw_string_ostream stream(buffer);
  OpPrintingFlags flags;
  // We may generate a lot of error messages and so we need to ensure the
  // printing is fast.
  flags.elideLargeElementsAttrs();
  flags.printGenericOpForm();
  flags.skipRegions();
  flags.useLocalScope();
  stream << "ERROR: Runtime op verification failed\n";
  op->print(stream, flags);
  stream << "\n^ " << msg;
  stream << "\nLocation: ";
  op->getLoc().print(stream);
  return stream.str();
}
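
// For illustration, the message produced for an out-of-bounds memref.load
// would look roughly like this (generic op form and location abridged):
//
//   ERROR: Runtime op verification failed
//   "memref.load"(%arg0, %arg1) : (memref<?xf32>, index) -> f32
//   ^ out-of-bounds access
//   Location: loc("example.mlir":3:10)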

namespace mlir {
namespace memref {
namespace {
struct CastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                         CastOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto castOp = cast<CastOp>(op);
    auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());

    // Nothing to check if the result is an unranked memref.
    auto resultType = dyn_cast<MemRefType>(castOp.getType());
    if (!resultType)
      return;

    if (isa<UnrankedMemRefType>(srcType)) {
      // Check rank.
      Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
      Value resultRank =
          builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
      Value isSameRank = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcRank, resultRank);
      builder.create<cf::AssertOp>(loc, isSameRank,
                                   generateErrorMessage(op, "rank mismatch"));
    }

    // Get source offset and strides. We do not have an op to get offsets and
    // strides from unranked memrefs, so cast the source to a type with fully
    // dynamic layout, from which we can then extract the offset and strides.
    // (Rank was already verified.)
    int64_t dynamicOffset = ShapedType::kDynamic;
    SmallVector<int64_t> dynamicShape(resultType.getRank(),
                                      ShapedType::kDynamic);
    auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
                                                dynamicOffset, dynamicShape);
    auto dynStridesType =
        MemRefType::get(dynamicShape, resultType.getElementType(),
                        stridedLayout, resultType.getMemorySpace());
    Value helperCast =
        builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
    auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);

    // Check dimension sizes.
    for (const auto &it : llvm::enumerate(resultType.getShape())) {
      // Static dim size -> static/dynamic dim size does not need verification.
      if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
        if (!rankedSrcType.isDynamicDim(it.index()))
          continue;

      // Static/dynamic dim size -> dynamic dim size does not need verification.
      if (resultType.isDynamicDim(it.index()))
        continue;

      Value srcDimSz =
          builder.create<DimOp>(loc, castOp.getSource(), it.index());
      Value resultDimSz =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameSz = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
      builder.create<cf::AssertOp>(
          loc, isSameSz,
          generateErrorMessage(op, "size mismatch of dim " +
                                       std::to_string(it.index())));
    }

    // Get result offset and strides.
    int64_t resultOffset;
    SmallVector<int64_t> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return;

    // Check offset.
    if (resultOffset != ShapedType::kDynamic) {
      // Static/dynamic offset -> dynamic offset does not need verification.
      Value srcOffset = metadataOp.getResult(1);
      Value resultOffsetVal =
          builder.create<arith::ConstantIndexOp>(loc, resultOffset);
      Value isSameOffset = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
      builder.create<cf::AssertOp>(loc, isSameOffset,
                                   generateErrorMessage(op, "offset mismatch"));
    }

    // Check strides.
    for (const auto &it : llvm::enumerate(resultStrides)) {
      // Static/dynamic stride -> dynamic stride does not need verification.
      if (it.value() == ShapedType::kDynamic)
        continue;

      Value srcStride =
          metadataOp.getResult(2 + resultType.getRank() + it.index());
      Value resultStrideVal =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameStride = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
      builder.create<cf::AssertOp>(
          loc, isSameStride,
          generateErrorMessage(op, "stride mismatch of dim " +
                                       std::to_string(it.index())));
    }
  }
};
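
// Rough sketch of the checks emitted for an unranked-to-ranked cast such as
//   %1 = memref.cast %0 : memref<*xf32> to memref<10xf32>
// (the exact IR depends on folding):
//   %rank = memref.rank %0 : memref<*xf32>
//   %c1 = arith.constant 1 : index
//   %rank_ok = arith.cmpi eq, %rank, %c1 : index
//   cf.assert %rank_ok, "..."   // "rank mismatch" message
// followed by per-dimension size checks and, for static layouts, offset and
// stride checks against the results of memref.extract_strided_metadata.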

/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
struct LoadStoreOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto loadStoreOp = cast<LoadStoreOp>(op);

    auto memref = loadStoreOp.getMemref();
    auto rank = memref.getType().getRank();
    if (rank == 0) {
      return;
    }
    auto indices = loadStoreOp.getIndices();

    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value assertCond;
    for (auto i : llvm::seq<int64_t>(0, rank)) {
      auto index = indices[i];

      auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);

      auto geLow = builder.createOrFold<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, index, zero);
      auto ltHigh = builder.createOrFold<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, index, dimOp);
      auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);

      assertCond =
          i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
                : andOp;
    }
    builder.create<cf::AssertOp>(
        loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
  }
};
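
// Rough sketch of the guard emitted for, e.g.,
//   %v = memref.load %m[%i] : memref<?xf32>
// (before folding):
//   %c0 = arith.constant 0 : index
//   %dim = memref.dim %m, %c0 : memref<?xf32>
//   %ge = arith.cmpi sge, %i, %c0 : index
//   %lt = arith.cmpi slt, %i, %dim : index
//   %ok = arith.andi %ge, %lt : i1
//   cf.assert %ok, "..."   // "out-of-bounds access" message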

/// Compute the linear index for the provided strided layout and indices.
Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  auto [expr, values] = computeLinearIndex(offset, strides, indices);
  auto index =
      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
  return getValueOrCreateConstantIndexOp(builder, loc, index);
}
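
// In other words, the helper materializes
//   linearIndex = offset + sum_i(indices[i] * strides[i])
// as a folded affine.apply, or as a constant index when all operands are
// statically known.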

/// Returns two Values representing the bounds of the provided strided layout
/// metadata. The bounds are returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
                                            OpFoldResult offset,
                                            ArrayRef<OpFoldResult> strides,
                                            ArrayRef<OpFoldResult> sizes) {
  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
  return {lowerBound, upperBound};
}
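
// With all-zero indices the lower bound folds to `offset`; the upper bound is
// offset + sum_i(sizes[i] * strides[i]). Assuming positive strides, this is at
// least one past the largest linear index reachable within the sizes, giving
// the exclusive end of the [low, high) interval.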

/// Returns two Values representing the bounds of the memref. The bounds are
/// returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
                                            TypedValue<BaseMemRefType> memref) {
  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
  auto offset = runtimeMetadata.getConstifiedMixedOffset();
  auto strides = runtimeMetadata.getConstifiedMixedStrides();
  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
  return computeLinearBounds(builder, loc, offset, strides, sizes);
}

/// Verifies that the linear bounds of a reinterpret_cast op are within the
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
struct ReinterpretCastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<
          ReinterpretCastOpInterface, ReinterpretCastOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto reinterpretCast = cast<ReinterpretCastOp>(op);
    auto baseMemref = reinterpretCast.getSource();
    auto resultMemref =
        cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());

    builder.setInsertionPointAfter(op);

    // Compute the linear bounds of the base memref
    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);

    // Compute the linear bounds of the resulting memref
    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);

    // Check low >= baseLow
    auto geLow = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, low, baseLow);

    // Check high <= baseHigh
    auto leHigh = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, high, baseHigh);

    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);

    builder.create<cf::AssertOp>(
        loc, assertCond,
        generateErrorMessage(
            op,
            "result of reinterpret_cast is out-of-bounds of the base memref"));
  }
};
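
// Rough sketch of the emitted check (the same pattern is used for subview
// below): with [baseLow, baseHigh) and [low, high) derived from
// memref.extract_strided_metadata, the op is followed by roughly
//   %ge = arith.cmpi sge, %low, %baseLow : index
//   %le = arith.cmpi sle, %high, %baseHigh : index
//   %ok = arith.andi %ge, %le : i1
//   cf.assert %ok, "..."   // out-of-bounds of the base memref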

/// Verifies that the linear bounds of a subview op are within the linear bounds
/// of the base memref: low >= baseLow && high <= baseHigh
/// TODO: This is not yet a full runtime verification of subview. For example,
/// consider:
///   %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
///   memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
///       : memref<?x?xf32> to memref<?x?xf32>
/// The subview is in-bounds of the entire base memref but the first dimension
/// is out-of-bounds. Future work would verify the bounds on a per-dimension
/// basis.
struct SubViewOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
                                                         SubViewOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto subView = cast<SubViewOp>(op);
    auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
    auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());

    builder.setInsertionPointAfter(op);

    // Compute the linear bounds of the base memref
    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);

    // Compute the linear bounds of the resulting memref
    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);

    // Check low >= baseLow
    auto geLow = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, low, baseLow);

    // Check high <= baseHigh
    auto leHigh = builder.createOrFold<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, high, baseHigh);

    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);

    builder.create<cf::AssertOp>(
        loc, assertCond,
        generateErrorMessage(op,
                             "subview is out-of-bounds of the base memref"));
  }
};

struct ExpandShapeOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                          ExpandShapeOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto expandShapeOp = cast<ExpandShapeOp>(op);

    // Verify that the expanded dim sizes are a product of the collapsed dim
    // size.
    for (const auto &it :
         llvm::enumerate(expandShapeOp.getReassociationIndices())) {
      Value srcDimSz =
          builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
      int64_t groupSz = 1;
      bool foundDynamicDim = false;
      for (int64_t resultDim : it.value()) {
        if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
          // Keep this assert here in case the op is extended in the future.
          assert(!foundDynamicDim &&
                 "more than one dynamic dim found in reassoc group");
          (void)foundDynamicDim;
          foundDynamicDim = true;
          continue;
        }
        groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
      }
      Value staticResultDimSz =
          builder.create<arith::ConstantIndexOp>(loc, groupSz);
      // staticResultDimSz must divide srcDimSz evenly.
      Value mod =
          builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
      Value isModZero = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, mod,
          builder.create<arith::ConstantIndexOp>(loc, 0));
      builder.create<cf::AssertOp>(
          loc, isModZero,
          generateErrorMessage(op, "static result dims in reassoc group do not "
                                   "divide src dim evenly"));
    }
  }
};
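
// For example, expanding a memref<?xf32> source into a memref<?x4xf32> result
// (reassociation group [0, 1]) emits roughly
//   %dim = memref.dim %src, %c0 : memref<?xf32>
//   %rem = arith.remsi %dim, %c4 : index
//   %ok = arith.cmpi eq, %rem, %c0 : index
//   cf.assert %ok, "..."   // src dim not evenly divisible by 4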
} // namespace
} // namespace memref
} // namespace mlir

void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
    ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
    StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
    SubViewOp::attachInterface<SubViewOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
                     cf::ControlFlowDialect>();
  });
}
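
// A minimal usage sketch (assuming the upstream -generate-runtime-verification
// pass; the exact pass factory name may differ across MLIR versions):
//
//   DialectRegistry registry;
//   registry.insert<memref::MemRefDialect>();
//   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);
//   // ... parse or build a ModuleOp `module` in `context` ...
//   PassManager pm(&context);
//   pm.addPass(createGenerateRuntimeVerificationPass());
//   if (failed(pm.run(module)))
//     llvm::report_fatal_error("runtime verification instrumentation failed");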