MLIR  19.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 
21 using namespace mlir;
22 
23 namespace mlir {
24 namespace memref {
25 namespace {
26 struct CastOpInterface
27  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
28  CastOp> {
29  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
30  Location loc) const {
31  auto castOp = cast<CastOp>(op);
32  auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
33 
34  // Nothing to check if the result is an unranked memref.
35  auto resultType = dyn_cast<MemRefType>(castOp.getType());
36  if (!resultType)
37  return;
38 
39  if (isa<UnrankedMemRefType>(srcType)) {
40  // Check rank.
41  Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
42  Value resultRank =
43  builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
44  Value isSameRank = builder.create<arith::CmpIOp>(
45  loc, arith::CmpIPredicate::eq, srcRank, resultRank);
46  builder.create<cf::AssertOp>(
47  loc, isSameRank,
48  RuntimeVerifiableOpInterface::generateErrorMessage(op,
49  "rank mismatch"));
50  }
51 
52  // Get source offset and strides. We do not have an op to get offsets and
53  // strides from unranked memrefs, so cast the source to a type with fully
54  // dynamic layout, from which we can then extract the offset and strides.
55  // (Rank was already verified.)
56  int64_t dynamicOffset = ShapedType::kDynamic;
57  SmallVector<int64_t> dynamicShape(resultType.getRank(),
58  ShapedType::kDynamic);
59  auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
60  dynamicOffset, dynamicShape);
61  auto dynStridesType =
62  MemRefType::get(dynamicShape, resultType.getElementType(),
63  stridedLayout, resultType.getMemorySpace());
64  Value helperCast =
65  builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
66  auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
67 
68  // Check dimension sizes.
69  for (const auto &it : llvm::enumerate(resultType.getShape())) {
70  // Static dim size -> static/dynamic dim size does not need verification.
71  if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
72  if (!rankedSrcType.isDynamicDim(it.index()))
73  continue;
74 
75  // Static/dynamic dim size -> dynamic dim size does not need verification.
76  if (resultType.isDynamicDim(it.index()))
77  continue;
78 
79  Value srcDimSz =
80  builder.create<DimOp>(loc, castOp.getSource(), it.index());
81  Value resultDimSz =
82  builder.create<arith::ConstantIndexOp>(loc, it.value());
83  Value isSameSz = builder.create<arith::CmpIOp>(
84  loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
85  builder.create<cf::AssertOp>(
86  loc, isSameSz,
87  RuntimeVerifiableOpInterface::generateErrorMessage(
88  op, "size mismatch of dim " + std::to_string(it.index())));
89  }
90 
91  // Get result offset and strides.
92  int64_t resultOffset;
93  SmallVector<int64_t> resultStrides;
94  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
95  return;
96 
97  // Check offset.
98  if (resultOffset != ShapedType::kDynamic) {
99  // Static/dynamic offset -> dynamic offset does not need verification.
100  Value srcOffset = metadataOp.getResult(1);
101  Value resultOffsetVal =
102  builder.create<arith::ConstantIndexOp>(loc, resultOffset);
103  Value isSameOffset = builder.create<arith::CmpIOp>(
104  loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
105  builder.create<cf::AssertOp>(
106  loc, isSameOffset,
107  RuntimeVerifiableOpInterface::generateErrorMessage(
108  op, "offset mismatch"));
109  }
110 
111  // Check strides.
112  for (const auto &it : llvm::enumerate(resultStrides)) {
113  // Static/dynamic stride -> dynamic stride does not need verification.
114  if (it.value() == ShapedType::kDynamic)
115  continue;
116 
117  Value srcStride =
118  metadataOp.getResult(2 + resultType.getRank() + it.index());
119  Value resultStrideVal =
120  builder.create<arith::ConstantIndexOp>(loc, it.value());
121  Value isSameStride = builder.create<arith::CmpIOp>(
122  loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
123  builder.create<cf::AssertOp>(
124  loc, isSameStride,
125  RuntimeVerifiableOpInterface::generateErrorMessage(
126  op, "stride mismatch of dim " + std::to_string(it.index())));
127  }
128  }
129 };
130 
131 /// Verifies that the indices on load/store ops are in-bounds of the memref's
132 /// index space: 0 <= index#i < dim#i
133 template <typename LoadStoreOp>
134 struct LoadStoreOpInterface
135  : public RuntimeVerifiableOpInterface::ExternalModel<
136  LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
137  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
138  Location loc) const {
139  auto loadStoreOp = cast<LoadStoreOp>(op);
140 
141  auto memref = loadStoreOp.getMemref();
142  auto rank = memref.getType().getRank();
143  if (rank == 0) {
144  return;
145  }
146  auto indices = loadStoreOp.getIndices();
147 
148  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
149  Value assertCond;
150  for (auto i : llvm::seq<int64_t>(0, rank)) {
151  auto index = indices[i];
152 
153  auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
154 
155  auto geLow = builder.createOrFold<arith::CmpIOp>(
156  loc, arith::CmpIPredicate::sge, index, zero);
157  auto ltHigh = builder.createOrFold<arith::CmpIOp>(
158  loc, arith::CmpIPredicate::slt, index, dimOp);
159  auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
160 
161  assertCond =
162  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
163  : andOp;
164  }
165  builder.create<cf::AssertOp>(
166  loc, assertCond,
167  RuntimeVerifiableOpInterface::generateErrorMessage(
168  op, "out-of-bounds access"));
169  }
170 };
171 
172 /// Compute the linear index for the provided strided layout and indices.
174  ArrayRef<OpFoldResult> strides,
175  ArrayRef<OpFoldResult> indices) {
176  auto [expr, values] = computeLinearIndex(offset, strides, indices);
177  auto index =
178  affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
179  return getValueOrCreateConstantIndexOp(builder, loc, index);
180 }
181 
182 /// Returns two Values representing the bounds of the provided strided layout
183 /// metadata. The bounds are returned as a half open interval -- [low, high).
184 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
185  OpFoldResult offset,
186  ArrayRef<OpFoldResult> strides,
187  ArrayRef<OpFoldResult> sizes) {
188  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
189  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
190  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
191  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
192  return {lowerBound, upperBound};
193 }
194 
195 /// Returns two Values representing the bounds of the memref. The bounds are
196 /// returned as a half open interval -- [low, high).
197 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
199  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
200  auto offset = runtimeMetadata.getConstifiedMixedOffset();
201  auto strides = runtimeMetadata.getConstifiedMixedStrides();
202  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
203  return computeLinearBounds(builder, loc, offset, strides, sizes);
204 }
205 
206 /// Verifies that the linear bounds of a reinterpret_cast op are within the
207 /// linear bounds of the base memref: low >= baseLow && high <= baseHigh
208 struct ReinterpretCastOpInterface
209  : public RuntimeVerifiableOpInterface::ExternalModel<
210  ReinterpretCastOpInterface, ReinterpretCastOp> {
211  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
212  Location loc) const {
213  auto reinterpretCast = cast<ReinterpretCastOp>(op);
214  auto baseMemref = reinterpretCast.getSource();
215  auto resultMemref =
216  cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
217 
218  builder.setInsertionPointAfter(op);
219 
220  // Compute the linear bounds of the base memref
221  auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
222 
223  // Compute the linear bounds of the resulting memref
224  auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
225 
226  // Check low >= baseLow
227  auto geLow = builder.createOrFold<arith::CmpIOp>(
228  loc, arith::CmpIPredicate::sge, low, baseLow);
229 
230  // Check high <= baseHigh
231  auto leHigh = builder.createOrFold<arith::CmpIOp>(
232  loc, arith::CmpIPredicate::sle, high, baseHigh);
233 
234  auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
235 
236  builder.create<cf::AssertOp>(
237  loc, assertCond,
238  RuntimeVerifiableOpInterface::generateErrorMessage(
239  op,
240  "result of reinterpret_cast is out-of-bounds of the base memref"));
241  }
242 };
243 
244 /// Verifies that the linear bounds of a subview op are within the linear bounds
245 /// of the base memref: low >= baseLow && high <= baseHigh
246 /// TODO: This is not yet a full runtime verification of subview. For example,
247 /// consider:
248 /// %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
249 /// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
250 /// : memref<?x?xf32> to memref<?x?xf32>
251 /// The subview is in-bounds of the entire base memref but the first dimension
252 /// is out-of-bounds. Future work would verify the bounds on a per-dimension
253 /// basis.
254 struct SubViewOpInterface
255  : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
256  SubViewOp> {
257  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
258  Location loc) const {
259  auto subView = cast<SubViewOp>(op);
260  auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
261  auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
262 
263  builder.setInsertionPointAfter(op);
264 
265  // Compute the linear bounds of the base memref
266  auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
267 
268  // Compute the linear bounds of the resulting memref
269  auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
270 
271  // Check low >= baseLow
272  auto geLow = builder.createOrFold<arith::CmpIOp>(
273  loc, arith::CmpIPredicate::sge, low, baseLow);
274 
275  // Check high <= baseHigh
276  auto leHigh = builder.createOrFold<arith::CmpIOp>(
277  loc, arith::CmpIPredicate::sle, high, baseHigh);
278 
279  auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
280 
281  builder.create<cf::AssertOp>(
282  loc, assertCond,
283  RuntimeVerifiableOpInterface::generateErrorMessage(
284  op, "subview is out-of-bounds of the base memref"));
285  }
286 };
287 
288 struct ExpandShapeOpInterface
289  : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
290  ExpandShapeOp> {
291  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
292  Location loc) const {
293  auto expandShapeOp = cast<ExpandShapeOp>(op);
294 
295  // Verify that the expanded dim sizes are a product of the collapsed dim
296  // size.
297  for (const auto &it :
298  llvm::enumerate(expandShapeOp.getReassociationIndices())) {
299  Value srcDimSz =
300  builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
301  int64_t groupSz = 1;
302  bool foundDynamicDim = false;
303  for (int64_t resultDim : it.value()) {
304  if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
305  // Keep this assert here in case the op is extended in the future.
306  assert(!foundDynamicDim &&
307  "more than one dynamic dim found in reassoc group");
308  (void)foundDynamicDim;
309  foundDynamicDim = true;
310  continue;
311  }
312  groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
313  }
314  Value staticResultDimSz =
315  builder.create<arith::ConstantIndexOp>(loc, groupSz);
316  // staticResultDimSz must divide srcDimSz evenly.
317  Value mod =
318  builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
319  Value isModZero = builder.create<arith::CmpIOp>(
320  loc, arith::CmpIPredicate::eq, mod,
321  builder.create<arith::ConstantIndexOp>(loc, 0));
322  builder.create<cf::AssertOp>(
323  loc, isModZero,
324  RuntimeVerifiableOpInterface::generateErrorMessage(
325  op, "static result dims in reassoc group do not "
326  "divide src dim evenly"));
327  }
328  }
329 };
330 } // namespace
331 } // namespace memref
332 } // namespace mlir
333 
335  DialectRegistry &registry) {
336  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
337  CastOp::attachInterface<CastOpInterface>(*ctx);
338  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
339  LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
340  ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
341  StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
342  SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
343 
344  // Load additional dialects of which ops may get created.
345  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
346  cf::ControlFlowDialect>();
347  });
348 }
MLIRContext * getContext() const
Definition: Builders.h:55
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...