MLIR  21.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 
21 using namespace mlir;
22 
23 namespace mlir {
24 namespace memref {
25 namespace {
26 /// Generate a runtime check for lb <= value < ub.
27 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
28  Value lb, Value ub) {
29  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
30  loc, arith::CmpIPredicate::sge, value, lb);
31  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
32  loc, arith::CmpIPredicate::slt, value, ub);
33  Value inBounds =
34  builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
35  return inBounds;
36 }
37 
38 struct AssumeAlignmentOpInterface
39  : public RuntimeVerifiableOpInterface::ExternalModel<
40  AssumeAlignmentOpInterface, AssumeAlignmentOp> {
41  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
42  Location loc) const {
43  auto assumeOp = cast<AssumeAlignmentOp>(op);
44  Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
45  loc, assumeOp.getMemref());
46  Value rest = builder.create<arith::RemUIOp>(
47  loc, ptr,
48  builder.create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
49  Value isAligned = builder.create<arith::CmpIOp>(
50  loc, arith::CmpIPredicate::eq, rest,
51  builder.create<arith::ConstantIndexOp>(loc, 0));
52  builder.create<cf::AssertOp>(
53  loc, isAligned,
54  RuntimeVerifiableOpInterface::generateErrorMessage(
55  op, "memref is not aligned to " +
56  std::to_string(assumeOp.getAlignment())));
57  }
58 };
59 
60 struct CastOpInterface
61  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
62  CastOp> {
63  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
64  Location loc) const {
65  auto castOp = cast<CastOp>(op);
66  auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
67 
68  // Nothing to check if the result is an unranked memref.
69  auto resultType = dyn_cast<MemRefType>(castOp.getType());
70  if (!resultType)
71  return;
72 
73  if (isa<UnrankedMemRefType>(srcType)) {
74  // Check rank.
75  Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
76  Value resultRank =
77  builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
78  Value isSameRank = builder.create<arith::CmpIOp>(
79  loc, arith::CmpIPredicate::eq, srcRank, resultRank);
80  builder.create<cf::AssertOp>(
81  loc, isSameRank,
82  RuntimeVerifiableOpInterface::generateErrorMessage(op,
83  "rank mismatch"));
84  }
85 
86  // Get source offset and strides. We do not have an op to get offsets and
87  // strides from unranked memrefs, so cast the source to a type with fully
88  // dynamic layout, from which we can then extract the offset and strides.
89  // (Rank was already verified.)
90  int64_t dynamicOffset = ShapedType::kDynamic;
91  SmallVector<int64_t> dynamicShape(resultType.getRank(),
92  ShapedType::kDynamic);
93  auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
94  dynamicOffset, dynamicShape);
95  auto dynStridesType =
96  MemRefType::get(dynamicShape, resultType.getElementType(),
97  stridedLayout, resultType.getMemorySpace());
98  Value helperCast =
99  builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
100  auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
101 
102  // Check dimension sizes.
103  for (const auto &it : llvm::enumerate(resultType.getShape())) {
104  // Static dim size -> static/dynamic dim size does not need verification.
105  if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
106  if (!rankedSrcType.isDynamicDim(it.index()))
107  continue;
108 
109  // Static/dynamic dim size -> dynamic dim size does not need verification.
110  if (resultType.isDynamicDim(it.index()))
111  continue;
112 
113  Value srcDimSz =
114  builder.create<DimOp>(loc, castOp.getSource(), it.index());
115  Value resultDimSz =
116  builder.create<arith::ConstantIndexOp>(loc, it.value());
117  Value isSameSz = builder.create<arith::CmpIOp>(
118  loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
119  builder.create<cf::AssertOp>(
120  loc, isSameSz,
121  RuntimeVerifiableOpInterface::generateErrorMessage(
122  op, "size mismatch of dim " + std::to_string(it.index())));
123  }
124 
125  // Get result offset and strides.
126  int64_t resultOffset;
127  SmallVector<int64_t> resultStrides;
128  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
129  return;
130 
131  // Check offset.
132  if (resultOffset != ShapedType::kDynamic) {
133  // Static/dynamic offset -> dynamic offset does not need verification.
134  Value srcOffset = metadataOp.getResult(1);
135  Value resultOffsetVal =
136  builder.create<arith::ConstantIndexOp>(loc, resultOffset);
137  Value isSameOffset = builder.create<arith::CmpIOp>(
138  loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
139  builder.create<cf::AssertOp>(
140  loc, isSameOffset,
141  RuntimeVerifiableOpInterface::generateErrorMessage(
142  op, "offset mismatch"));
143  }
144 
145  // Check strides.
146  for (const auto &it : llvm::enumerate(resultStrides)) {
147  // Static/dynamic stride -> dynamic stride does not need verification.
148  if (it.value() == ShapedType::kDynamic)
149  continue;
150 
151  Value srcStride =
152  metadataOp.getResult(2 + resultType.getRank() + it.index());
153  Value resultStrideVal =
154  builder.create<arith::ConstantIndexOp>(loc, it.value());
155  Value isSameStride = builder.create<arith::CmpIOp>(
156  loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
157  builder.create<cf::AssertOp>(
158  loc, isSameStride,
159  RuntimeVerifiableOpInterface::generateErrorMessage(
160  op, "stride mismatch of dim " + std::to_string(it.index())));
161  }
162  }
163 };
164 
165 struct CopyOpInterface
166  : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
167  CopyOp> {
168  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
169  Location loc) const {
170  auto copyOp = cast<CopyOp>(op);
171  BaseMemRefType sourceType = copyOp.getSource().getType();
172  BaseMemRefType targetType = copyOp.getTarget().getType();
173  auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
174  auto rankedTargetType = dyn_cast<MemRefType>(targetType);
175 
176  // TODO: Verification for unranked memrefs is not supported yet.
177  if (!rankedSourceType || !rankedTargetType)
178  return;
179 
180  assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
181  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
182  // Fully static dimensions in both source and target operand are already
183  // verified by the op verifier.
184  if (!rankedSourceType.isDynamicDim(i) &&
185  !rankedTargetType.isDynamicDim(i))
186  continue;
187  auto getDimSize = [&](Value memRef, MemRefType type,
188  int64_t dim) -> Value {
189  return type.isDynamicDim(dim)
190  ? builder.create<DimOp>(loc, memRef, dim).getResult()
191  : builder
192  .create<arith::ConstantIndexOp>(loc,
193  type.getDimSize(dim))
194  .getResult();
195  };
196  Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
197  Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
198  Value sameDimSize = builder.create<arith::CmpIOp>(
199  loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
200  builder.create<cf::AssertOp>(
201  loc, sameDimSize,
202  RuntimeVerifiableOpInterface::generateErrorMessage(
203  op, "size of " + std::to_string(i) +
204  "-th source/target dim does not match"));
205  }
206  }
207 };
208 
209 struct DimOpInterface
210  : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
211  DimOp> {
212  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
213  Location loc) const {
214  auto dimOp = cast<DimOp>(op);
215  Value rank = builder.create<RankOp>(loc, dimOp.getSource());
216  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
217  builder.create<cf::AssertOp>(
218  loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
219  RuntimeVerifiableOpInterface::generateErrorMessage(
220  op, "index is out of bounds"));
221  }
222 };
223 
224 /// Verifies that the indices on load/store ops are in-bounds of the memref's
225 /// index space: 0 <= index#i < dim#i
226 template <typename LoadStoreOp>
227 struct LoadStoreOpInterface
228  : public RuntimeVerifiableOpInterface::ExternalModel<
229  LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
230  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
231  Location loc) const {
232  auto loadStoreOp = cast<LoadStoreOp>(op);
233 
234  auto memref = loadStoreOp.getMemref();
235  auto rank = memref.getType().getRank();
236  if (rank == 0) {
237  return;
238  }
239  auto indices = loadStoreOp.getIndices();
240 
241  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
242  Value assertCond;
243  for (auto i : llvm::seq<int64_t>(0, rank)) {
244  Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
245  Value inBounds =
246  generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
247  assertCond =
248  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
249  : inBounds;
250  }
251  builder.create<cf::AssertOp>(
252  loc, assertCond,
253  RuntimeVerifiableOpInterface::generateErrorMessage(
254  op, "out-of-bounds access"));
255  }
256 };
257 
258 /// Compute the linear index for the provided strided layout and indices.
260  ArrayRef<OpFoldResult> strides,
261  ArrayRef<OpFoldResult> indices) {
262  auto [expr, values] = computeLinearIndex(offset, strides, indices);
263  auto index =
264  affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
265  return getValueOrCreateConstantIndexOp(builder, loc, index);
266 }
267 
268 /// Returns two Values representing the bounds of the provided strided layout
269 /// metadata. The bounds are returned as a half open interval -- [low, high).
270 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
271  OpFoldResult offset,
272  ArrayRef<OpFoldResult> strides,
273  ArrayRef<OpFoldResult> sizes) {
274  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
275  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
276  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
277  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
278  return {lowerBound, upperBound};
279 }
280 
281 /// Returns two Values representing the bounds of the memref. The bounds are
282 /// returned as a half open interval -- [low, high).
283 std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
285  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
286  auto offset = runtimeMetadata.getConstifiedMixedOffset();
287  auto strides = runtimeMetadata.getConstifiedMixedStrides();
288  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
289  return computeLinearBounds(builder, loc, offset, strides, sizes);
290 }
291 
292 /// Verifies that the linear bounds of a reinterpret_cast op are within the
293 /// linear bounds of the base memref: low >= baseLow && high <= baseHigh
294 struct ReinterpretCastOpInterface
295  : public RuntimeVerifiableOpInterface::ExternalModel<
296  ReinterpretCastOpInterface, ReinterpretCastOp> {
297  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
298  Location loc) const {
299  auto reinterpretCast = cast<ReinterpretCastOp>(op);
300  auto baseMemref = reinterpretCast.getSource();
301  auto resultMemref =
302  cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
303 
304  builder.setInsertionPointAfter(op);
305 
306  // Compute the linear bounds of the base memref
307  auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
308 
309  // Compute the linear bounds of the resulting memref
310  auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
311 
312  // Check low >= baseLow
313  auto geLow = builder.createOrFold<arith::CmpIOp>(
314  loc, arith::CmpIPredicate::sge, low, baseLow);
315 
316  // Check high <= baseHigh
317  auto leHigh = builder.createOrFold<arith::CmpIOp>(
318  loc, arith::CmpIPredicate::sle, high, baseHigh);
319 
320  auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
321 
322  builder.create<cf::AssertOp>(
323  loc, assertCond,
324  RuntimeVerifiableOpInterface::generateErrorMessage(
325  op,
326  "result of reinterpret_cast is out-of-bounds of the base memref"));
327  }
328 };
329 
330 struct SubViewOpInterface
331  : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
332  SubViewOp> {
333  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
334  Location loc) const {
335  auto subView = cast<SubViewOp>(op);
336  MemRefType sourceType = subView.getSource().getType();
337 
338  // For each dimension, assert that:
339  // 0 <= offset < dim_size
340  // 0 <= offset + (size - 1) * stride < dim_size
341  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
342  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
343  auto metadataOp =
344  builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
345  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
347  builder, loc, subView.getMixedOffsets()[i]);
348  Value size = getValueOrCreateConstantIndexOp(builder, loc,
349  subView.getMixedSizes()[i]);
351  builder, loc, subView.getMixedStrides()[i]);
352 
353  // Verify that offset is in-bounds.
354  Value dimSize = metadataOp.getSizes()[i];
355  Value offsetInBounds =
356  generateInBoundsCheck(builder, loc, offset, zero, dimSize);
357  builder.create<cf::AssertOp>(
358  loc, offsetInBounds,
359  RuntimeVerifiableOpInterface::generateErrorMessage(
360  op, "offset " + std::to_string(i) + " is out-of-bounds"));
361 
362  // Verify that slice does not run out-of-bounds.
363  Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
364  Value sizeMinusOneTimesStride =
365  builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
366  Value lastPos =
367  builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
368  Value lastPosInBounds =
369  generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
370  builder.create<cf::AssertOp>(
371  loc, lastPosInBounds,
372  RuntimeVerifiableOpInterface::generateErrorMessage(
373  op, "subview runs out-of-bounds along dimension " +
374  std::to_string(i)));
375  }
376  }
377 };
378 
379 struct ExpandShapeOpInterface
380  : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
381  ExpandShapeOp> {
382  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
383  Location loc) const {
384  auto expandShapeOp = cast<ExpandShapeOp>(op);
385 
386  // Verify that the expanded dim sizes are a product of the collapsed dim
387  // size.
388  for (const auto &it :
389  llvm::enumerate(expandShapeOp.getReassociationIndices())) {
390  Value srcDimSz =
391  builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
392  int64_t groupSz = 1;
393  bool foundDynamicDim = false;
394  for (int64_t resultDim : it.value()) {
395  if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
396  // Keep this assert here in case the op is extended in the future.
397  assert(!foundDynamicDim &&
398  "more than one dynamic dim found in reassoc group");
399  (void)foundDynamicDim;
400  foundDynamicDim = true;
401  continue;
402  }
403  groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
404  }
405  Value staticResultDimSz =
406  builder.create<arith::ConstantIndexOp>(loc, groupSz);
407  // staticResultDimSz must divide srcDimSz evenly.
408  Value mod =
409  builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
410  Value isModZero = builder.create<arith::CmpIOp>(
411  loc, arith::CmpIPredicate::eq, mod,
412  builder.create<arith::ConstantIndexOp>(loc, 0));
413  builder.create<cf::AssertOp>(
414  loc, isModZero,
415  RuntimeVerifiableOpInterface::generateErrorMessage(
416  op, "static result dims in reassoc group do not "
417  "divide src dim evenly"));
418  }
419  }
420 };
421 } // namespace
422 } // namespace memref
423 } // namespace mlir
424 
426  DialectRegistry &registry) {
427  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
428  AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
429  CastOp::attachInterface<CastOpInterface>(*ctx);
430  CopyOp::attachInterface<CopyOpInterface>(*ctx);
431  DimOp::attachInterface<DimOpInterface>(*ctx);
432  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
433  LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
434  ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
435  StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
436  SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
437 
438  // Load additional dialects of which ops may get created.
439  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
440  cf::ControlFlowDialect>();
441  });
442 }
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:102
MLIRContext * getContext() const
Definition: Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands, then immediately attempts to fold it.
Definition: AffineOps.cpp:1217
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
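Attaches the runtime-verification external models (assume_alignment, cast, copy, dim, expand_shape, load, reinterpret_cast, store, subview) to the memref dialect's ops in the given registry.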
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...