MLIR  22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 
20 using namespace mlir;
21 
22 namespace mlir {
23 namespace memref {
24 namespace {
25 /// Generate a runtime check for lb <= value < ub.
26 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
27  Value lb, Value ub) {
28  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
29  loc, arith::CmpIPredicate::sge, value, lb);
30  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
31  loc, arith::CmpIPredicate::slt, value, ub);
32  Value inBounds =
33  builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
34  return inBounds;
35 }
36 
37 struct AssumeAlignmentOpInterface
38  : public RuntimeVerifiableOpInterface::ExternalModel<
39  AssumeAlignmentOpInterface, AssumeAlignmentOp> {
40  void
41  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
42  function_ref<std::string(Operation *, StringRef)>
43  generateErrorMessage) const {
44  auto assumeOp = cast<AssumeAlignmentOp>(op);
45  Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
46  assumeOp.getMemref());
47  Value rest = arith::RemUIOp::create(
48  builder, loc, ptr,
49  arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment()));
50  Value isAligned =
51  arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
52  arith::ConstantIndexOp::create(builder, loc, 0));
53  cf::AssertOp::create(
54  builder, loc, isAligned,
55  generateErrorMessage(op, "memref is not aligned to " +
56  std::to_string(assumeOp.getAlignment())));
57  }
58 };
59 
60 struct CastOpInterface
61  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
62  CastOp> {
63  void
64  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
65  function_ref<std::string(Operation *, StringRef)>
66  generateErrorMessage) const {
67  auto castOp = cast<CastOp>(op);
68  auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
69 
70  // Nothing to check if the result is an unranked memref.
71  auto resultType = dyn_cast<MemRefType>(castOp.getType());
72  if (!resultType)
73  return;
74 
75  if (isa<UnrankedMemRefType>(srcType)) {
76  // Check rank.
77  Value srcRank = RankOp::create(builder, loc, castOp.getSource());
78  Value resultRank =
79  arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
80  Value isSameRank = arith::CmpIOp::create(
81  builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
82  cf::AssertOp::create(builder, loc, isSameRank,
83  generateErrorMessage(op, "rank mismatch"));
84  }
85 
86  // Get source offset and strides. We do not have an op to get offsets and
87  // strides from unranked memrefs, so cast the source to a type with fully
88  // dynamic layout, from which we can then extract the offset and strides.
89  // (Rank was already verified.)
90  int64_t dynamicOffset = ShapedType::kDynamic;
91  SmallVector<int64_t> dynamicShape(resultType.getRank(),
92  ShapedType::kDynamic);
93  auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
94  dynamicOffset, dynamicShape);
95  auto dynStridesType =
96  MemRefType::get(dynamicShape, resultType.getElementType(),
97  stridedLayout, resultType.getMemorySpace());
98  Value helperCast =
99  CastOp::create(builder, loc, dynStridesType, castOp.getSource());
100  auto metadataOp =
101  ExtractStridedMetadataOp::create(builder, loc, helperCast);
102 
103  // Check dimension sizes.
104  for (const auto &it : llvm::enumerate(resultType.getShape())) {
105  // Static dim size -> static/dynamic dim size does not need verification.
106  if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
107  if (!rankedSrcType.isDynamicDim(it.index()))
108  continue;
109 
110  // Static/dynamic dim size -> dynamic dim size does not need verification.
111  if (resultType.isDynamicDim(it.index()))
112  continue;
113 
114  Value srcDimSz =
115  DimOp::create(builder, loc, castOp.getSource(), it.index());
116  Value resultDimSz =
117  arith::ConstantIndexOp::create(builder, loc, it.value());
118  Value isSameSz = arith::CmpIOp::create(
119  builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
120  cf::AssertOp::create(
121  builder, loc, isSameSz,
122  generateErrorMessage(op, "size mismatch of dim " +
123  std::to_string(it.index())));
124  }
125 
126  // Get result offset and strides.
127  int64_t resultOffset;
128  SmallVector<int64_t> resultStrides;
129  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
130  return;
131 
132  // Check offset.
133  if (resultOffset != ShapedType::kDynamic) {
134  // Static/dynamic offset -> dynamic offset does not need verification.
135  Value srcOffset = metadataOp.getResult(1);
136  Value resultOffsetVal =
137  arith::ConstantIndexOp::create(builder, loc, resultOffset);
138  Value isSameOffset = arith::CmpIOp::create(
139  builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
140  cf::AssertOp::create(builder, loc, isSameOffset,
141  generateErrorMessage(op, "offset mismatch"));
142  }
143 
144  // Check strides.
145  for (const auto &it : llvm::enumerate(resultStrides)) {
146  // Static/dynamic stride -> dynamic stride does not need verification.
147  if (it.value() == ShapedType::kDynamic)
148  continue;
149 
150  Value srcStride =
151  metadataOp.getResult(2 + resultType.getRank() + it.index());
152  Value resultStrideVal =
153  arith::ConstantIndexOp::create(builder, loc, it.value());
154  Value isSameStride = arith::CmpIOp::create(
155  builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
156  cf::AssertOp::create(
157  builder, loc, isSameStride,
158  generateErrorMessage(op, "stride mismatch of dim " +
159  std::to_string(it.index())));
160  }
161  }
162 };
163 
164 struct CopyOpInterface
165  : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
166  CopyOp> {
167  void
168  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
169  function_ref<std::string(Operation *, StringRef)>
170  generateErrorMessage) const {
171  auto copyOp = cast<CopyOp>(op);
172  BaseMemRefType sourceType = copyOp.getSource().getType();
173  BaseMemRefType targetType = copyOp.getTarget().getType();
174  auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
175  auto rankedTargetType = dyn_cast<MemRefType>(targetType);
176 
177  // TODO: Verification for unranked memrefs is not supported yet.
178  if (!rankedSourceType || !rankedTargetType)
179  return;
180 
181  assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
182  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
183  // Fully static dimensions in both source and target operand are already
184  // verified by the op verifier.
185  if (!rankedSourceType.isDynamicDim(i) &&
186  !rankedTargetType.isDynamicDim(i))
187  continue;
188  auto getDimSize = [&](Value memRef, MemRefType type,
189  int64_t dim) -> Value {
190  return type.isDynamicDim(dim)
191  ? DimOp::create(builder, loc, memRef, dim).getResult()
192  : arith::ConstantIndexOp::create(builder, loc,
193  type.getDimSize(dim))
194  .getResult();
195  };
196  Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
197  Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
198  Value sameDimSize = arith::CmpIOp::create(
199  builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
200  cf::AssertOp::create(
201  builder, loc, sameDimSize,
202  generateErrorMessage(op, "size of " + std::to_string(i) +
203  "-th source/target dim does not match"));
204  }
205  }
206 };
207 
208 struct DimOpInterface
209  : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
210  DimOp> {
211  void
212  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
213  function_ref<std::string(Operation *, StringRef)>
214  generateErrorMessage) const {
215  auto dimOp = cast<DimOp>(op);
216  Value rank = RankOp::create(builder, loc, dimOp.getSource());
217  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
218  cf::AssertOp::create(
219  builder, loc,
220  generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
221  generateErrorMessage(op, "index is out of bounds"));
222  }
223 };
224 
225 /// Verifies that the indices on load/store ops are in-bounds of the memref's
226 /// index space: 0 <= index#i < dim#i
227 template <typename LoadStoreOp>
228 struct LoadStoreOpInterface
229  : public RuntimeVerifiableOpInterface::ExternalModel<
230  LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
231  void
232  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
233  function_ref<std::string(Operation *, StringRef)>
234  generateErrorMessage) const {
235  auto loadStoreOp = cast<LoadStoreOp>(op);
236 
237  auto memref = loadStoreOp.getMemref();
238  auto rank = memref.getType().getRank();
239  if (rank == 0) {
240  return;
241  }
242  auto indices = loadStoreOp.getIndices();
243 
244  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
245  Value assertCond;
246  for (auto i : llvm::seq<int64_t>(0, rank)) {
247  Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
248  Value inBounds =
249  generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
250  assertCond =
251  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
252  : inBounds;
253  }
254  cf::AssertOp::create(builder, loc, assertCond,
255  generateErrorMessage(op, "out-of-bounds access"));
256  }
257 };
258 
259 struct SubViewOpInterface
260  : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
261  SubViewOp> {
262  void
263  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
264  function_ref<std::string(Operation *, StringRef)>
265  generateErrorMessage) const {
266  auto subView = cast<SubViewOp>(op);
267  MemRefType sourceType = subView.getSource().getType();
268 
269  // For each dimension, assert that:
270  // 0 <= offset < dim_size
271  // 0 <= offset + (size - 1) * stride < dim_size
272  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
273  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
274  auto metadataOp =
275  ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
276  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
278  builder, loc, subView.getMixedOffsets()[i]);
279  Value size = getValueOrCreateConstantIndexOp(builder, loc,
280  subView.getMixedSizes()[i]);
282  builder, loc, subView.getMixedStrides()[i]);
283 
284  // Verify that offset is in-bounds.
285  Value dimSize = metadataOp.getSizes()[i];
286  Value offsetInBounds =
287  generateInBoundsCheck(builder, loc, offset, zero, dimSize);
288  cf::AssertOp::create(builder, loc, offsetInBounds,
289  generateErrorMessage(op, "offset " +
290  std::to_string(i) +
291  " is out-of-bounds"));
292 
293  // Verify that slice does not run out-of-bounds.
294  Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
295  Value sizeMinusOneTimesStride =
296  arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
297  Value lastPos =
298  arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
299  Value lastPosInBounds =
300  generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
301  cf::AssertOp::create(
302  builder, loc, lastPosInBounds,
303  generateErrorMessage(op,
304  "subview runs out-of-bounds along dimension " +
305  std::to_string(i)));
306  }
307  }
308 };
309 
310 struct ExpandShapeOpInterface
311  : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
312  ExpandShapeOp> {
313  void
314  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
315  function_ref<std::string(Operation *, StringRef)>
316  generateErrorMessage) const {
317  auto expandShapeOp = cast<ExpandShapeOp>(op);
318 
319  // Verify that the expanded dim sizes are a product of the collapsed dim
320  // size.
321  for (const auto &it :
322  llvm::enumerate(expandShapeOp.getReassociationIndices())) {
323  Value srcDimSz =
324  DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
325  int64_t groupSz = 1;
326  bool foundDynamicDim = false;
327  for (int64_t resultDim : it.value()) {
328  if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
329  // Keep this assert here in case the op is extended in the future.
330  assert(!foundDynamicDim &&
331  "more than one dynamic dim found in reassoc group");
332  (void)foundDynamicDim;
333  foundDynamicDim = true;
334  continue;
335  }
336  groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
337  }
338  Value staticResultDimSz =
339  arith::ConstantIndexOp::create(builder, loc, groupSz);
340  // staticResultDimSz must divide srcDimSz evenly.
341  Value mod =
342  arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
343  Value isModZero = arith::CmpIOp::create(
344  builder, loc, arith::CmpIPredicate::eq, mod,
345  arith::ConstantIndexOp::create(builder, loc, 0));
346  cf::AssertOp::create(
347  builder, loc, isModZero,
348  generateErrorMessage(op, "static result dims in reassoc group do not "
349  "divide src dim evenly"));
350  }
351  }
352 };
353 } // namespace
354 } // namespace memref
355 } // namespace mlir
356 
358  DialectRegistry &registry) {
359  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
360  AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
361  AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
362  CastOp::attachInterface<CastOpInterface>(*ctx);
363  CopyOp::attachInterface<CopyOpInterface>(*ctx);
364  DimOp::attachInterface<DimOpInterface>(*ctx);
365  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
366  GenericAtomicRMWOp::attachInterface<
367  LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
368  LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
369  StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
370  SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
371  // Note: There is nothing to verify for ReinterpretCastOp.
372 
373  // Load additional dialects of which ops may get created.
374  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
375  cf::ControlFlowDialect>();
376  });
377 }
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
MLIRContext * getContext() const
Definition: Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...