MLIR  21.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 
21 using namespace mlir;
22 
23 namespace mlir {
24 namespace memref {
25 namespace {
/// Generates runtime verification for memref.cast: when casting to a ranked
/// result type, emits cf.assert checks that the source's rank (for unranked
/// sources), static dim sizes, static offset, and static strides actually
/// match the result type at runtime.
struct CastOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
                                                         CastOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto castOp = cast<CastOp>(op);
    auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());

    // Nothing to check if the result is an unranked memref.
    auto resultType = dyn_cast<MemRefType>(castOp.getType());
    if (!resultType)
      return;

    if (isa<UnrankedMemRefType>(srcType)) {
      // Check rank.
      Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
      Value resultRank =
          builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
      Value isSameRank = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcRank, resultRank);
      builder.create<cf::AssertOp>(
          loc, isSameRank,
          RuntimeVerifiableOpInterface::generateErrorMessage(op,
                                                             "rank mismatch"));
    }

    // Get source offset and strides. We do not have an op to get offsets and
    // strides from unranked memrefs, so cast the source to a type with fully
    // dynamic layout, from which we can then extract the offset and strides.
    // (Rank was already verified.)
    int64_t dynamicOffset = ShapedType::kDynamic;
    SmallVector<int64_t> dynamicShape(resultType.getRank(),
                                      ShapedType::kDynamic);
    auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
                                                dynamicOffset, dynamicShape);
    auto dynStridesType =
        MemRefType::get(dynamicShape, resultType.getElementType(),
                        stridedLayout, resultType.getMemorySpace());
    Value helperCast =
        builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
    auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);

    // Check dimension sizes.
    for (const auto &it : llvm::enumerate(resultType.getShape())) {
      // Static dim size -> static/dynamic dim size does not need verification.
      if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
        if (!rankedSrcType.isDynamicDim(it.index()))
          continue;

      // Static/dynamic dim size -> dynamic dim size does not need verification.
      if (resultType.isDynamicDim(it.index()))
        continue;

      Value srcDimSz =
          builder.create<DimOp>(loc, castOp.getSource(), it.index());
      Value resultDimSz =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameSz = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
      builder.create<cf::AssertOp>(
          loc, isSameSz,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "size mismatch of dim " + std::to_string(it.index())));
    }

    // Get result offset and strides. Bail out if the result layout is not
    // strided; there is nothing to compare against in that case.
    int64_t resultOffset;
    SmallVector<int64_t> resultStrides;
    if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
      return;

    // Check offset.
    if (resultOffset != ShapedType::kDynamic) {
      // Static/dynamic offset -> dynamic offset does not need verification.
      // Result #1 of ExtractStridedMetadataOp is the runtime offset (results
      // are laid out as [base, offset, sizes..., strides...]).
      Value srcOffset = metadataOp.getResult(1);
      Value resultOffsetVal =
          builder.create<arith::ConstantIndexOp>(loc, resultOffset);
      Value isSameOffset = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
      builder.create<cf::AssertOp>(
          loc, isSameOffset,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "offset mismatch"));
    }

    // Check strides.
    for (const auto &it : llvm::enumerate(resultStrides)) {
      // Static/dynamic stride -> dynamic stride does not need verification.
      if (it.value() == ShapedType::kDynamic)
        continue;

      // Stride i lives at metadata result index 2 + rank + i (skipping the
      // base, the offset, and the `rank` size results).
      Value srcStride =
          metadataOp.getResult(2 + resultType.getRank() + it.index());
      Value resultStrideVal =
          builder.create<arith::ConstantIndexOp>(loc, it.value());
      Value isSameStride = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
      builder.create<cf::AssertOp>(
          loc, isSameStride,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "stride mismatch of dim " + std::to_string(it.index())));
    }
  }
};
130 
131 struct CopyOpInterface
132  : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
133  CopyOp> {
134  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
135  Location loc) const {
136  auto copyOp = cast<CopyOp>(op);
137  BaseMemRefType sourceType = copyOp.getSource().getType();
138  BaseMemRefType targetType = copyOp.getTarget().getType();
139  auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
140  auto rankedTargetType = dyn_cast<MemRefType>(targetType);
141 
142  // TODO: Verification for unranked memrefs is not supported yet.
143  if (!rankedSourceType || !rankedTargetType)
144  return;
145 
146  assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
147  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
148  // Fully static dimensions in both source and target operand are already
149  // verified by the op verifier.
150  if (!rankedSourceType.isDynamicDim(i) &&
151  !rankedTargetType.isDynamicDim(i))
152  continue;
153  auto getDimSize = [&](Value memRef, MemRefType type,
154  int64_t dim) -> Value {
155  return type.isDynamicDim(dim)
156  ? builder.create<DimOp>(loc, memRef, dim).getResult()
157  : builder
158  .create<arith::ConstantIndexOp>(loc,
159  type.getDimSize(dim))
160  .getResult();
161  };
162  Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
163  Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
164  Value sameDimSize = builder.create<arith::CmpIOp>(
165  loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
166  builder.create<cf::AssertOp>(
167  loc, sameDimSize,
168  RuntimeVerifiableOpInterface::generateErrorMessage(
169  op, "size of " + std::to_string(i) +
170  "-th source/target dim does not match"));
171  }
172  }
173 };
174 
175 /// Verifies that the indices on load/store ops are in-bounds of the memref's
176 /// index space: 0 <= index#i < dim#i
177 template <typename LoadStoreOp>
178 struct LoadStoreOpInterface
179  : public RuntimeVerifiableOpInterface::ExternalModel<
180  LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
181  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
182  Location loc) const {
183  auto loadStoreOp = cast<LoadStoreOp>(op);
184 
185  auto memref = loadStoreOp.getMemref();
186  auto rank = memref.getType().getRank();
187  if (rank == 0) {
188  return;
189  }
190  auto indices = loadStoreOp.getIndices();
191 
192  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
193  Value assertCond;
194  for (auto i : llvm::seq<int64_t>(0, rank)) {
195  auto index = indices[i];
196 
197  auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
198 
199  auto geLow = builder.createOrFold<arith::CmpIOp>(
200  loc, arith::CmpIPredicate::sge, index, zero);
201  auto ltHigh = builder.createOrFold<arith::CmpIOp>(
202  loc, arith::CmpIPredicate::slt, index, dimOp);
203  auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
204 
205  assertCond =
206  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
207  : andOp;
208  }
209  builder.create<cf::AssertOp>(
210  loc, assertCond,
211  RuntimeVerifiableOpInterface::generateErrorMessage(
212  op, "out-of-bounds access"));
213  }
214 };
215 
216 /// Compute the linear index for the provided strided layout and indices.
218  ArrayRef<OpFoldResult> strides,
219  ArrayRef<OpFoldResult> indices) {
220  auto [expr, values] = computeLinearIndex(offset, strides, indices);
221  auto index =
222  affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
223  return getValueOrCreateConstantIndexOp(builder, loc, index);
224 }
225 
226 /// Returns two Values representing the bounds of the provided strided layout
227 /// metadata. The bounds are returned as a half open interval -- [low, high).
228 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
229  OpFoldResult offset,
230  ArrayRef<OpFoldResult> strides,
231  ArrayRef<OpFoldResult> sizes) {
232  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
233  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
234  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
235  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
236  return {lowerBound, upperBound};
237 }
238 
239 /// Returns two Values representing the bounds of the memref. The bounds are
240 /// returned as a half open interval -- [low, high).
241 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
243  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
244  auto offset = runtimeMetadata.getConstifiedMixedOffset();
245  auto strides = runtimeMetadata.getConstifiedMixedStrides();
246  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
247  return computeLinearBounds(builder, loc, offset, strides, sizes);
248 }
249 
250 /// Verifies that the linear bounds of a reinterpret_cast op are within the
251 /// linear bounds of the base memref: low >= baseLow && high <= baseHigh
252 struct ReinterpretCastOpInterface
253  : public RuntimeVerifiableOpInterface::ExternalModel<
254  ReinterpretCastOpInterface, ReinterpretCastOp> {
255  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
256  Location loc) const {
257  auto reinterpretCast = cast<ReinterpretCastOp>(op);
258  auto baseMemref = reinterpretCast.getSource();
259  auto resultMemref =
260  cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
261 
262  builder.setInsertionPointAfter(op);
263 
264  // Compute the linear bounds of the base memref
265  auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
266 
267  // Compute the linear bounds of the resulting memref
268  auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
269 
270  // Check low >= baseLow
271  auto geLow = builder.createOrFold<arith::CmpIOp>(
272  loc, arith::CmpIPredicate::sge, low, baseLow);
273 
274  // Check high <= baseHigh
275  auto leHigh = builder.createOrFold<arith::CmpIOp>(
276  loc, arith::CmpIPredicate::sle, high, baseHigh);
277 
278  auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
279 
280  builder.create<cf::AssertOp>(
281  loc, assertCond,
282  RuntimeVerifiableOpInterface::generateErrorMessage(
283  op,
284  "result of reinterpret_cast is out-of-bounds of the base memref"));
285  }
286 };
287 
288 /// Verifies that the linear bounds of a subview op are within the linear bounds
289 /// of the base memref: low >= baseLow && high <= baseHigh
290 /// TODO: This is not yet a full runtime verification of subview. For example,
291 /// consider:
292 /// %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
293 /// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
294 /// : memref<?x?xf32> to memref<?x?xf32>
295 /// The subview is in-bounds of the entire base memref but the first dimension
296 /// is out-of-bounds. Future work would verify the bounds on a per-dimension
297 /// basis.
298 struct SubViewOpInterface
299  : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
300  SubViewOp> {
301  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
302  Location loc) const {
303  auto subView = cast<SubViewOp>(op);
304  auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
305  auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
306 
307  builder.setInsertionPointAfter(op);
308 
309  // Compute the linear bounds of the base memref
310  auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
311 
312  // Compute the linear bounds of the resulting memref
313  auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
314 
315  // Check low >= baseLow
316  auto geLow = builder.createOrFold<arith::CmpIOp>(
317  loc, arith::CmpIPredicate::sge, low, baseLow);
318 
319  // Check high <= baseHigh
320  auto leHigh = builder.createOrFold<arith::CmpIOp>(
321  loc, arith::CmpIPredicate::sle, high, baseHigh);
322 
323  auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
324 
325  builder.create<cf::AssertOp>(
326  loc, assertCond,
327  RuntimeVerifiableOpInterface::generateErrorMessage(
328  op, "subview is out-of-bounds of the base memref"));
329  }
330 };
331 
/// Generates runtime verification for memref.expand_shape: for each
/// reassociation group, asserts that the product of the group's *static*
/// result dim sizes divides the corresponding collapsed source dim size
/// evenly.
struct ExpandShapeOpInterface
    : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                         ExpandShapeOp> {
  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                   Location loc) const {
    auto expandShapeOp = cast<ExpandShapeOp>(op);

    // Verify that the expanded dim sizes are a product of the collapsed dim
    // size.
    for (const auto &it :
         llvm::enumerate(expandShapeOp.getReassociationIndices())) {
      // Runtime size of the collapsed source dim for this group.
      Value srcDimSz =
          builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
      // Accumulate the product of the static result dim sizes; dynamic dims
      // (at most one per group) are skipped.
      int64_t groupSz = 1;
      bool foundDynamicDim = false;
      for (int64_t resultDim : it.value()) {
        if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
          // Keep this assert here in case the op is extended in the future.
          assert(!foundDynamicDim &&
                 "more than one dynamic dim found in reassoc group");
          (void)foundDynamicDim;
          foundDynamicDim = true;
          continue;
        }
        groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
      }
      Value staticResultDimSz =
          builder.create<arith::ConstantIndexOp>(loc, groupSz);
      // staticResultDimSz must divide srcDimSz evenly.
      Value mod =
          builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
      Value isModZero = builder.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, mod,
          builder.create<arith::ConstantIndexOp>(loc, 0));
      builder.create<cf::AssertOp>(
          loc, isModZero,
          RuntimeVerifiableOpInterface::generateErrorMessage(
              op, "static result dims in reassoc group do not "
                  "divide src dim evenly"));
    }
  }
};
374 } // namespace
375 } // namespace memref
376 } // namespace mlir
377 
379  DialectRegistry &registry) {
380  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
381  CastOp::attachInterface<CastOpInterface>(*ctx);
382  CopyOp::attachInterface<CopyOpInterface>(*ctx);
383  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
384  LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
385  ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
386  StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
387  SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
388 
389  // Load additional dialects of which ops may get created.
390  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
391  cf::ControlFlowDialect>();
392  });
393 }
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:102
MLIRContext * getContext() const
Definition: Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1208
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...