MLIR  22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 
21 using namespace mlir;
22 
23 namespace mlir {
24 namespace memref {
25 namespace {
26 /// Generate a runtime check for lb <= value < ub.
27 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
28  Value lb, Value ub) {
29  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
30  loc, arith::CmpIPredicate::sge, value, lb);
31  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
32  loc, arith::CmpIPredicate::slt, value, ub);
33  Value inBounds =
34  builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
35  return inBounds;
36 }
37 
38 struct AssumeAlignmentOpInterface
39  : public RuntimeVerifiableOpInterface::ExternalModel<
40  AssumeAlignmentOpInterface, AssumeAlignmentOp> {
41  void
42  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
43  function_ref<std::string(Operation *, StringRef)>
44  generateErrorMessage) const {
45  auto assumeOp = cast<AssumeAlignmentOp>(op);
46  Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
47  assumeOp.getMemref());
48  Value rest = arith::RemUIOp::create(
49  builder, loc, ptr,
50  arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment()));
51  Value isAligned =
52  arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
53  arith::ConstantIndexOp::create(builder, loc, 0));
54  cf::AssertOp::create(
55  builder, loc, isAligned,
56  generateErrorMessage(op, "memref is not aligned to " +
57  std::to_string(assumeOp.getAlignment())));
58  }
59 };
60 
61 struct CastOpInterface
62  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
63  CastOp> {
64  void
65  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
66  function_ref<std::string(Operation *, StringRef)>
67  generateErrorMessage) const {
68  auto castOp = cast<CastOp>(op);
69  auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
70 
71  // Nothing to check if the result is an unranked memref.
72  auto resultType = dyn_cast<MemRefType>(castOp.getType());
73  if (!resultType)
74  return;
75 
76  if (isa<UnrankedMemRefType>(srcType)) {
77  // Check rank.
78  Value srcRank = RankOp::create(builder, loc, castOp.getSource());
79  Value resultRank =
80  arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
81  Value isSameRank = arith::CmpIOp::create(
82  builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
83  cf::AssertOp::create(builder, loc, isSameRank,
84  generateErrorMessage(op, "rank mismatch"));
85  }
86 
87  // Get source offset and strides. We do not have an op to get offsets and
88  // strides from unranked memrefs, so cast the source to a type with fully
89  // dynamic layout, from which we can then extract the offset and strides.
90  // (Rank was already verified.)
91  int64_t dynamicOffset = ShapedType::kDynamic;
92  SmallVector<int64_t> dynamicShape(resultType.getRank(),
93  ShapedType::kDynamic);
94  auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
95  dynamicOffset, dynamicShape);
96  auto dynStridesType =
97  MemRefType::get(dynamicShape, resultType.getElementType(),
98  stridedLayout, resultType.getMemorySpace());
99  Value helperCast =
100  CastOp::create(builder, loc, dynStridesType, castOp.getSource());
101  auto metadataOp =
102  ExtractStridedMetadataOp::create(builder, loc, helperCast);
103 
104  // Check dimension sizes.
105  for (const auto &it : llvm::enumerate(resultType.getShape())) {
106  // Static dim size -> static/dynamic dim size does not need verification.
107  if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
108  if (!rankedSrcType.isDynamicDim(it.index()))
109  continue;
110 
111  // Static/dynamic dim size -> dynamic dim size does not need verification.
112  if (resultType.isDynamicDim(it.index()))
113  continue;
114 
115  Value srcDimSz =
116  DimOp::create(builder, loc, castOp.getSource(), it.index());
117  Value resultDimSz =
118  arith::ConstantIndexOp::create(builder, loc, it.value());
119  Value isSameSz = arith::CmpIOp::create(
120  builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
121  cf::AssertOp::create(
122  builder, loc, isSameSz,
123  generateErrorMessage(op, "size mismatch of dim " +
124  std::to_string(it.index())));
125  }
126 
127  // Get result offset and strides.
128  int64_t resultOffset;
129  SmallVector<int64_t> resultStrides;
130  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
131  return;
132 
133  // Check offset.
134  if (resultOffset != ShapedType::kDynamic) {
135  // Static/dynamic offset -> dynamic offset does not need verification.
136  Value srcOffset = metadataOp.getResult(1);
137  Value resultOffsetVal =
138  arith::ConstantIndexOp::create(builder, loc, resultOffset);
139  Value isSameOffset = arith::CmpIOp::create(
140  builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
141  cf::AssertOp::create(builder, loc, isSameOffset,
142  generateErrorMessage(op, "offset mismatch"));
143  }
144 
145  // Check strides.
146  for (const auto &it : llvm::enumerate(resultStrides)) {
147  // Static/dynamic stride -> dynamic stride does not need verification.
148  if (it.value() == ShapedType::kDynamic)
149  continue;
150 
151  Value srcStride =
152  metadataOp.getResult(2 + resultType.getRank() + it.index());
153  Value resultStrideVal =
154  arith::ConstantIndexOp::create(builder, loc, it.value());
155  Value isSameStride = arith::CmpIOp::create(
156  builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
157  cf::AssertOp::create(
158  builder, loc, isSameStride,
159  generateErrorMessage(op, "stride mismatch of dim " +
160  std::to_string(it.index())));
161  }
162  }
163 };
164 
165 struct CopyOpInterface
166  : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
167  CopyOp> {
168  void
169  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
170  function_ref<std::string(Operation *, StringRef)>
171  generateErrorMessage) const {
172  auto copyOp = cast<CopyOp>(op);
173  BaseMemRefType sourceType = copyOp.getSource().getType();
174  BaseMemRefType targetType = copyOp.getTarget().getType();
175  auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
176  auto rankedTargetType = dyn_cast<MemRefType>(targetType);
177 
178  // TODO: Verification for unranked memrefs is not supported yet.
179  if (!rankedSourceType || !rankedTargetType)
180  return;
181 
182  assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
183  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
184  // Fully static dimensions in both source and target operand are already
185  // verified by the op verifier.
186  if (!rankedSourceType.isDynamicDim(i) &&
187  !rankedTargetType.isDynamicDim(i))
188  continue;
189  auto getDimSize = [&](Value memRef, MemRefType type,
190  int64_t dim) -> Value {
191  return type.isDynamicDim(dim)
192  ? DimOp::create(builder, loc, memRef, dim).getResult()
193  : arith::ConstantIndexOp::create(builder, loc,
194  type.getDimSize(dim))
195  .getResult();
196  };
197  Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
198  Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
199  Value sameDimSize = arith::CmpIOp::create(
200  builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
201  cf::AssertOp::create(
202  builder, loc, sameDimSize,
203  generateErrorMessage(op, "size of " + std::to_string(i) +
204  "-th source/target dim does not match"));
205  }
206  }
207 };
208 
209 struct DimOpInterface
210  : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
211  DimOp> {
212  void
213  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
214  function_ref<std::string(Operation *, StringRef)>
215  generateErrorMessage) const {
216  auto dimOp = cast<DimOp>(op);
217  Value rank = RankOp::create(builder, loc, dimOp.getSource());
218  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
219  cf::AssertOp::create(
220  builder, loc,
221  generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
222  generateErrorMessage(op, "index is out of bounds"));
223  }
224 };
225 
226 /// Verifies that the indices on load/store ops are in-bounds of the memref's
227 /// index space: 0 <= index#i < dim#i
228 template <typename LoadStoreOp>
229 struct LoadStoreOpInterface
230  : public RuntimeVerifiableOpInterface::ExternalModel<
231  LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
232  void
233  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
234  function_ref<std::string(Operation *, StringRef)>
235  generateErrorMessage) const {
236  auto loadStoreOp = cast<LoadStoreOp>(op);
237 
238  auto memref = loadStoreOp.getMemref();
239  auto rank = memref.getType().getRank();
240  if (rank == 0) {
241  return;
242  }
243  auto indices = loadStoreOp.getIndices();
244 
245  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
246  Value assertCond;
247  for (auto i : llvm::seq<int64_t>(0, rank)) {
248  Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
249  Value inBounds =
250  generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
251  assertCond =
252  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
253  : inBounds;
254  }
255  cf::AssertOp::create(builder, loc, assertCond,
256  generateErrorMessage(op, "out-of-bounds access"));
257  }
258 };
259 
260 struct SubViewOpInterface
261  : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
262  SubViewOp> {
263  void
264  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
265  function_ref<std::string(Operation *, StringRef)>
266  generateErrorMessage) const {
267  auto subView = cast<SubViewOp>(op);
268  MemRefType sourceType = subView.getSource().getType();
269 
270  // For each dimension, assert that:
271  // 0 <= offset < dim_size
272  // 0 <= offset + (size - 1) * stride < dim_size
273  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
274  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
275  auto metadataOp =
276  ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
277  for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
278  // Reset insertion point to before the operation for each dimension
279  builder.setInsertionPoint(subView);
281  builder, loc, subView.getMixedOffsets()[i]);
282  Value size = getValueOrCreateConstantIndexOp(builder, loc,
283  subView.getMixedSizes()[i]);
285  builder, loc, subView.getMixedStrides()[i]);
286 
287  // Verify that offset is in-bounds.
288  Value dimSize = metadataOp.getSizes()[i];
289  Value offsetInBounds =
290  generateInBoundsCheck(builder, loc, offset, zero, dimSize);
291  cf::AssertOp::create(builder, loc, offsetInBounds,
292  generateErrorMessage(op, "offset " +
293  std::to_string(i) +
294  " is out-of-bounds"));
295 
296  // Only verify if size > 0
297  Value sizeIsNonZero = arith::CmpIOp::create(
298  builder, loc, arith::CmpIPredicate::sgt, size, zero);
299 
300  auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
301  sizeIsNonZero, /*withElseRegion=*/true);
302 
303  // Populate the "then" region (for size > 0).
304  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
305 
306  // Verify that slice does not run out-of-bounds.
307  Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
308  Value sizeMinusOneTimesStride =
309  arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
310  Value lastPos =
311  arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
312  Value lastPosInBounds =
313  generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
314 
315  scf::YieldOp::create(builder, loc, lastPosInBounds);
316 
317  // Populate the "else" region (for size == 0).
318  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
319  Value trueVal =
320  arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
321  scf::YieldOp::create(builder, loc, trueVal);
322 
323  builder.setInsertionPointAfter(ifOp);
324  Value finalCondition = ifOp.getResult(0);
325 
326  cf::AssertOp::create(
327  builder, loc, finalCondition,
328  generateErrorMessage(op,
329  "subview runs out-of-bounds along dimension " +
330  std::to_string(i)));
331  }
332  }
333 };
334 
335 struct ExpandShapeOpInterface
336  : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
337  ExpandShapeOp> {
338  void
339  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
340  function_ref<std::string(Operation *, StringRef)>
341  generateErrorMessage) const {
342  auto expandShapeOp = cast<ExpandShapeOp>(op);
343 
344  // Verify that the expanded dim sizes are a product of the collapsed dim
345  // size.
346  for (const auto &it :
347  llvm::enumerate(expandShapeOp.getReassociationIndices())) {
348  Value srcDimSz =
349  DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
350  int64_t groupSz = 1;
351  bool foundDynamicDim = false;
352  for (int64_t resultDim : it.value()) {
353  if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
354  // Keep this assert here in case the op is extended in the future.
355  assert(!foundDynamicDim &&
356  "more than one dynamic dim found in reassoc group");
357  (void)foundDynamicDim;
358  foundDynamicDim = true;
359  continue;
360  }
361  groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
362  }
363  Value staticResultDimSz =
364  arith::ConstantIndexOp::create(builder, loc, groupSz);
365  // staticResultDimSz must divide srcDimSz evenly.
366  Value mod =
367  arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
368  Value isModZero = arith::CmpIOp::create(
369  builder, loc, arith::CmpIPredicate::eq, mod,
370  arith::ConstantIndexOp::create(builder, loc, 0));
371  cf::AssertOp::create(
372  builder, loc, isModZero,
373  generateErrorMessage(op, "static result dims in reassoc group do not "
374  "divide src dim evenly"));
375  }
376  }
377 };
378 } // namespace
379 } // namespace memref
380 } // namespace mlir
381 
383  DialectRegistry &registry) {
384  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
385  AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
386  AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
387  CastOp::attachInterface<CastOpInterface>(*ctx);
388  CopyOp::attachInterface<CopyOpInterface>(*ctx);
389  DimOp::attachInterface<DimOpInterface>(*ctx);
390  ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
391  GenericAtomicRMWOp::attachInterface<
392  LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
393  LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
394  StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
395  SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
396  // Note: There is nothing to verify for ReinterpretCastOp.
397 
398  // Load additional dialects of which ops may get created.
399  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
400  cf::ControlFlowDialect>();
401  });
402 }
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:100
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.