MLIR  22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 
20 using namespace mlir;
21 
22 namespace mlir {
23 namespace memref {
24 namespace {
25 /// Generate a runtime check for lb <= value < ub.
26 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
27  Value lb, Value ub) {
28  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
29  loc, arith::CmpIPredicate::sge, value, lb);
30  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
31  loc, arith::CmpIPredicate::slt, value, ub);
32  Value inBounds =
33  builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
34  return inBounds;
35 }
36 
37 struct AssumeAlignmentOpInterface
38  : public RuntimeVerifiableOpInterface::ExternalModel<
39  AssumeAlignmentOpInterface, AssumeAlignmentOp> {
40  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
41  Location loc) const {
42  auto assumeOp = cast<AssumeAlignmentOp>(op);
43  Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
44  assumeOp.getMemref());
45  Value rest = arith::RemUIOp::create(
46  builder, loc, ptr,
47  arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment()));
48  Value isAligned =
49  arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
50  arith::ConstantIndexOp::create(builder, loc, 0));
51  cf::AssertOp::create(builder, loc, isAligned,
52  RuntimeVerifiableOpInterface::generateErrorMessage(
53  op, "memref is not aligned to " +
54  std::to_string(assumeOp.getAlignment())));
55  }
56 };
57 
58 struct CastOpInterface
59  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
60  CastOp> {
61  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
62  Location loc) const {
63  auto castOp = cast<CastOp>(op);
64  auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
65 
66  // Nothing to check if the result is an unranked memref.
67  auto resultType = dyn_cast<MemRefType>(castOp.getType());
68  if (!resultType)
69  return;
70 
71  if (isa<UnrankedMemRefType>(srcType)) {
72  // Check rank.
73  Value srcRank = RankOp::create(builder, loc, castOp.getSource());
74  Value resultRank =
75  arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
76  Value isSameRank = arith::CmpIOp::create(
77  builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
78  cf::AssertOp::create(builder, loc, isSameRank,
79  RuntimeVerifiableOpInterface::generateErrorMessage(
80  op, "rank mismatch"));
81  }
82 
83  // Get source offset and strides. We do not have an op to get offsets and
84  // strides from unranked memrefs, so cast the source to a type with fully
85  // dynamic layout, from which we can then extract the offset and strides.
86  // (Rank was already verified.)
87  int64_t dynamicOffset = ShapedType::kDynamic;
88  SmallVector<int64_t> dynamicShape(resultType.getRank(),
89  ShapedType::kDynamic);
90  auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
91  dynamicOffset, dynamicShape);
92  auto dynStridesType =
93  MemRefType::get(dynamicShape, resultType.getElementType(),
94  stridedLayout, resultType.getMemorySpace());
95  Value helperCast =
96  CastOp::create(builder, loc, dynStridesType, castOp.getSource());
97  auto metadataOp =
98  ExtractStridedMetadataOp::create(builder, loc, helperCast);
99 
100  // Check dimension sizes.
101  for (const auto &it : llvm::enumerate(resultType.getShape())) {
102  // Static dim size -> static/dynamic dim size does not need verification.
103  if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
104  if (!rankedSrcType.isDynamicDim(it.index()))
105  continue;
106 
107  // Static/dynamic dim size -> dynamic dim size does not need verification.
108  if (resultType.isDynamicDim(it.index()))
109  continue;
110 
111  Value srcDimSz =
112  DimOp::create(builder, loc, castOp.getSource(), it.index());
113  Value resultDimSz =
114  arith::ConstantIndexOp::create(builder, loc, it.value());
115  Value isSameSz = arith::CmpIOp::create(
116  builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
117  cf::AssertOp::create(
118  builder, loc, isSameSz,
119  RuntimeVerifiableOpInterface::generateErrorMessage(
120  op, "size mismatch of dim " + std::to_string(it.index())));
121  }
122 
123  // Get result offset and strides.
124  int64_t resultOffset;
125  SmallVector<int64_t> resultStrides;
126  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
127  return;
128 
129  // Check offset.
130  if (resultOffset != ShapedType::kDynamic) {
131  // Static/dynamic offset -> dynamic offset does not need verification.
132  Value srcOffset = metadataOp.getResult(1);
133  Value resultOffsetVal =
134  arith::ConstantIndexOp::create(builder, loc, resultOffset);
135  Value isSameOffset = arith::CmpIOp::create(
136  builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
137  cf::AssertOp::create(builder, loc, isSameOffset,
138  RuntimeVerifiableOpInterface::generateErrorMessage(
139  op, "offset mismatch"));
140  }
141 
142  // Check strides.
143  for (const auto &it : llvm::enumerate(resultStrides)) {
144  // Static/dynamic stride -> dynamic stride does not need verification.
145  if (it.value() == ShapedType::kDynamic)
146  continue;
147 
148  Value srcStride =
149  metadataOp.getResult(2 + resultType.getRank() + it.index());
150  Value resultStrideVal =
151  arith::ConstantIndexOp::create(builder, loc, it.value());
152  Value isSameStride = arith::CmpIOp::create(
153  builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
154  cf::AssertOp::create(
155  builder, loc, isSameStride,
156  RuntimeVerifiableOpInterface::generateErrorMessage(
157  op, "stride mismatch of dim " + std::to_string(it.index())));
158  }
159  }
160 };
161 
162 struct CopyOpInterface
163  : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
164  CopyOp> {
165  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
166  Location loc) const {
167  auto copyOp = cast<CopyOp>(op);
168  BaseMemRefType sourceType = copyOp.getSource().getType();
169  BaseMemRefType targetType = copyOp.getTarget().getType();
170  auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
171  auto rankedTargetType = dyn_cast<MemRefType>(targetType);
172 
173  // TODO: Verification for unranked memrefs is not supported yet.
174  if (!rankedSourceType || !rankedTargetType)
175  return;
176 
177  assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
178  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
179  // Fully static dimensions in both source and target operand are already
180  // verified by the op verifier.
181  if (!rankedSourceType.isDynamicDim(i) &&
182  !rankedTargetType.isDynamicDim(i))
183  continue;
184  auto getDimSize = [&](Value memRef, MemRefType type,
185  int64_t dim) -> Value {
186  return type.isDynamicDim(dim)
187  ? DimOp::create(builder, loc, memRef, dim).getResult()
188  : arith::ConstantIndexOp::create(builder, loc,
189  type.getDimSize(dim))
190  .getResult();
191  };
192  Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
193  Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
194  Value sameDimSize = arith::CmpIOp::create(
195  builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
196  cf::AssertOp::create(builder, loc, sameDimSize,
197  RuntimeVerifiableOpInterface::generateErrorMessage(
198  op, "size of " + std::to_string(i) +
199  "-th source/target dim does not match"));
200  }
201  }
202 };
203 
204 struct DimOpInterface
205  : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
206  DimOp> {
207  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
208  Location loc) const {
209  auto dimOp = cast<DimOp>(op);
210  Value rank = RankOp::create(builder, loc, dimOp.getSource());
211  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
212  cf::AssertOp::create(
213  builder, loc,
214  generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
215  RuntimeVerifiableOpInterface::generateErrorMessage(
216  op, "index is out of bounds"));
217  }
218 };
219 
220 /// Verifies that the indices on load/store ops are in-bounds of the memref's
221 /// index space: 0 <= index#i < dim#i
222 template <typename LoadStoreOp>
223 struct LoadStoreOpInterface
224  : public RuntimeVerifiableOpInterface::ExternalModel<
225  LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
226  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
227  Location loc) const {
228  auto loadStoreOp = cast<LoadStoreOp>(op);
229 
230  auto memref = loadStoreOp.getMemref();
231  auto rank = memref.getType().getRank();
232  if (rank == 0) {
233  return;
234  }
235  auto indices = loadStoreOp.getIndices();
236 
237  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
238  Value assertCond;
239  for (auto i : llvm::seq<int64_t>(0, rank)) {
240  Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
241  Value inBounds =
242  generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
243  assertCond =
244  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
245  : inBounds;
246  }
247  cf::AssertOp::create(builder, loc, assertCond,
248  RuntimeVerifiableOpInterface::generateErrorMessage(
249  op, "out-of-bounds access"));
250  }
251 };
252 
253 struct SubViewOpInterface
254  : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
255  SubViewOp> {
256  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
257  Location loc) const {
258  auto subView = cast<SubViewOp>(op);
259  MemRefType sourceType = subView.getSource().getType();
260 
261  // For each dimension, assert that:
262  // 0 <= offset < dim_size
263  // 0 <= offset + (size - 1) * stride < dim_size
264  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
265  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
266  auto metadataOp =
267  ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
268  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
270  builder, loc, subView.getMixedOffsets()[i]);
271  Value size = getValueOrCreateConstantIndexOp(builder, loc,
272  subView.getMixedSizes()[i]);
274  builder, loc, subView.getMixedStrides()[i]);
275 
276  // Verify that offset is in-bounds.
277  Value dimSize = metadataOp.getSizes()[i];
278  Value offsetInBounds =
279  generateInBoundsCheck(builder, loc, offset, zero, dimSize);
280  cf::AssertOp::create(
281  builder, loc, offsetInBounds,
282  RuntimeVerifiableOpInterface::generateErrorMessage(
283  op, "offset " + std::to_string(i) + " is out-of-bounds"));
284 
285  // Verify that slice does not run out-of-bounds.
286  Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
287  Value sizeMinusOneTimesStride =
288  arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
289  Value lastPos =
290  arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
291  Value lastPosInBounds =
292  generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
293  cf::AssertOp::create(
294  builder, loc, lastPosInBounds,
295  RuntimeVerifiableOpInterface::generateErrorMessage(
296  op, "subview runs out-of-bounds along dimension " +
297  std::to_string(i)));
298  }
299  }
300 };
301 
302 struct ExpandShapeOpInterface
303  : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
304  ExpandShapeOp> {
305  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
306  Location loc) const {
307  auto expandShapeOp = cast<ExpandShapeOp>(op);
308 
309  // Verify that the expanded dim sizes are a product of the collapsed dim
310  // size.
311  for (const auto &it :
312  llvm::enumerate(expandShapeOp.getReassociationIndices())) {
313  Value srcDimSz =
314  DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
315  int64_t groupSz = 1;
316  bool foundDynamicDim = false;
317  for (int64_t resultDim : it.value()) {
318  if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
319  // Keep this assert here in case the op is extended in the future.
320  assert(!foundDynamicDim &&
321  "more than one dynamic dim found in reassoc group");
322  (void)foundDynamicDim;
323  foundDynamicDim = true;
324  continue;
325  }
326  groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
327  }
328  Value staticResultDimSz =
329  arith::ConstantIndexOp::create(builder, loc, groupSz);
330  // staticResultDimSz must divide srcDimSz evenly.
331  Value mod =
332  arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
333  Value isModZero = arith::CmpIOp::create(
334  builder, loc, arith::CmpIPredicate::eq, mod,
335  arith::ConstantIndexOp::create(builder, loc, 0));
336  cf::AssertOp::create(builder, loc, isModZero,
337  RuntimeVerifiableOpInterface::generateErrorMessage(
338  op, "static result dims in reassoc group do not "
339  "divide src dim evenly"));
340  }
341  }
342 };
343 } // namespace
344 } // namespace memref
345 } // namespace mlir
346 
348  DialectRegistry &registry) {
349  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
350  AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
351  AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
352  CastOp::attachInterface<CastOpInterface>(*ctx);
353  CopyOp::attachInterface<CopyOpInterface>(*ctx);
354  DimOp::attachInterface<DimOpInterface>(*ctx);
355  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
356  GenericAtomicRMWOp::attachInterface<
357  LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
358  LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
359  StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
360  SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
361  // Note: There is nothing to verify for ReinterpretCastOp.
362 
363  // Load additional dialects of which ops may get created.
364  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
365  cf::ControlFlowDialect>();
366  });
367 }
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
MLIRContext * getContext() const
Definition: Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:519
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.