MLIR 22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20
21using namespace mlir;
22
23namespace mlir {
24namespace memref {
25namespace {
26/// Generate a runtime check for lb <= value < ub.
27Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
28 Value lb, Value ub) {
29 Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
30 loc, arith::CmpIPredicate::sge, value, lb);
31 Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
32 loc, arith::CmpIPredicate::slt, value, ub);
33 Value inBounds =
34 builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
35 return inBounds;
36}
37
38struct AssumeAlignmentOpInterface
39 : public RuntimeVerifiableOpInterface::ExternalModel<
40 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
41 void
42 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
43 function_ref<std::string(Operation *, StringRef)>
44 generateErrorMessage) const {
45 auto assumeOp = cast<AssumeAlignmentOp>(op);
46 Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
47 assumeOp.getMemref());
48 Value rest = arith::RemUIOp::create(
49 builder, loc, ptr,
50 arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment()));
51 Value isAligned =
52 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
53 arith::ConstantIndexOp::create(builder, loc, 0));
54 cf::AssertOp::create(
55 builder, loc, isAligned,
56 generateErrorMessage(op, "memref is not aligned to " +
57 std::to_string(assumeOp.getAlignment())));
58 }
59};
60
61struct CastOpInterface
62 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
63 CastOp> {
64 void
65 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
66 function_ref<std::string(Operation *, StringRef)>
67 generateErrorMessage) const {
68 auto castOp = cast<CastOp>(op);
69 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
70
71 // Nothing to check if the result is an unranked memref.
72 auto resultType = dyn_cast<MemRefType>(castOp.getType());
73 if (!resultType)
74 return;
75
76 if (isa<UnrankedMemRefType>(srcType)) {
77 // Check rank.
78 Value srcRank = RankOp::create(builder, loc, castOp.getSource());
79 Value resultRank =
80 arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
81 Value isSameRank = arith::CmpIOp::create(
82 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
83 cf::AssertOp::create(builder, loc, isSameRank,
84 generateErrorMessage(op, "rank mismatch"));
85 }
86
87 // Get source offset and strides. We do not have an op to get offsets and
88 // strides from unranked memrefs, so cast the source to a type with fully
89 // dynamic layout, from which we can then extract the offset and strides.
90 // (Rank was already verified.)
91 int64_t dynamicOffset = ShapedType::kDynamic;
92 SmallVector<int64_t> dynamicShape(resultType.getRank(),
93 ShapedType::kDynamic);
94 auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
95 dynamicOffset, dynamicShape);
96 auto dynStridesType =
97 MemRefType::get(dynamicShape, resultType.getElementType(),
98 stridedLayout, resultType.getMemorySpace());
99 Value helperCast =
100 CastOp::create(builder, loc, dynStridesType, castOp.getSource());
101 auto metadataOp =
102 ExtractStridedMetadataOp::create(builder, loc, helperCast);
103
104 // Check dimension sizes.
105 for (const auto &it : llvm::enumerate(resultType.getShape())) {
106 // Static dim size -> static/dynamic dim size does not need verification.
107 if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
108 if (!rankedSrcType.isDynamicDim(it.index()))
109 continue;
110
111 // Static/dynamic dim size -> dynamic dim size does not need verification.
112 if (resultType.isDynamicDim(it.index()))
113 continue;
114
115 Value srcDimSz =
116 DimOp::create(builder, loc, castOp.getSource(), it.index());
117 Value resultDimSz =
118 arith::ConstantIndexOp::create(builder, loc, it.value());
119 Value isSameSz = arith::CmpIOp::create(
120 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
121 cf::AssertOp::create(
122 builder, loc, isSameSz,
123 generateErrorMessage(op, "size mismatch of dim " +
124 std::to_string(it.index())));
125 }
126
127 // Get result offset and strides.
128 int64_t resultOffset;
129 SmallVector<int64_t> resultStrides;
130 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
131 return;
132
133 // Check offset.
134 if (resultOffset != ShapedType::kDynamic) {
135 // Static/dynamic offset -> dynamic offset does not need verification.
136 Value srcOffset = metadataOp.getResult(1);
137 Value resultOffsetVal =
138 arith::ConstantIndexOp::create(builder, loc, resultOffset);
139 Value isSameOffset = arith::CmpIOp::create(
140 builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
141 cf::AssertOp::create(builder, loc, isSameOffset,
142 generateErrorMessage(op, "offset mismatch"));
143 }
144
145 // Check strides.
146 for (const auto &it : llvm::enumerate(resultStrides)) {
147 // Static/dynamic stride -> dynamic stride does not need verification.
148 if (it.value() == ShapedType::kDynamic)
149 continue;
150
151 Value srcStride =
152 metadataOp.getResult(2 + resultType.getRank() + it.index());
153 Value resultStrideVal =
154 arith::ConstantIndexOp::create(builder, loc, it.value());
155 Value isSameStride = arith::CmpIOp::create(
156 builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
157 cf::AssertOp::create(
158 builder, loc, isSameStride,
159 generateErrorMessage(op, "stride mismatch of dim " +
160 std::to_string(it.index())));
161 }
162 }
163};
164
165struct CopyOpInterface
166 : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
167 CopyOp> {
168 void
169 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
170 function_ref<std::string(Operation *, StringRef)>
171 generateErrorMessage) const {
172 auto copyOp = cast<CopyOp>(op);
173 BaseMemRefType sourceType = copyOp.getSource().getType();
174 BaseMemRefType targetType = copyOp.getTarget().getType();
175 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
176 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
177
178 // TODO: Verification for unranked memrefs is not supported yet.
179 if (!rankedSourceType || !rankedTargetType)
180 return;
181
182 assert(sourceType.getRank() == targetType.getRank() && "rank mismatch");
183 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
184 // Fully static dimensions in both source and target operand are already
185 // verified by the op verifier.
186 if (!rankedSourceType.isDynamicDim(i) &&
187 !rankedTargetType.isDynamicDim(i))
188 continue;
189 auto getDimSize = [&](Value memRef, MemRefType type,
190 int64_t dim) -> Value {
191 return type.isDynamicDim(dim)
192 ? DimOp::create(builder, loc, memRef, dim).getResult()
193 : arith::ConstantIndexOp::create(builder, loc,
194 type.getDimSize(dim))
195 .getResult();
196 };
197 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
198 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
199 Value sameDimSize = arith::CmpIOp::create(
200 builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
201 cf::AssertOp::create(
202 builder, loc, sameDimSize,
203 generateErrorMessage(op, "size of " + std::to_string(i) +
204 "-th source/target dim does not match"));
205 }
206 }
207};
208
209struct DimOpInterface
210 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
211 DimOp> {
212 void
213 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
214 function_ref<std::string(Operation *, StringRef)>
215 generateErrorMessage) const {
216 auto dimOp = cast<DimOp>(op);
217 Value rank = RankOp::create(builder, loc, dimOp.getSource());
218 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
219 cf::AssertOp::create(
220 builder, loc,
221 generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
222 generateErrorMessage(op, "index is out of bounds"));
223 }
224};
225
226/// Verifies that the indices on load/store ops are in-bounds of the memref's
227/// index space: 0 <= index#i < dim#i
228template <typename LoadStoreOp>
229struct LoadStoreOpInterface
230 : public RuntimeVerifiableOpInterface::ExternalModel<
231 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
232 void
233 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
234 function_ref<std::string(Operation *, StringRef)>
235 generateErrorMessage) const {
236 auto loadStoreOp = cast<LoadStoreOp>(op);
237
238 auto memref = loadStoreOp.getMemref();
239 auto rank = memref.getType().getRank();
240 if (rank == 0) {
241 return;
242 }
243 auto indices = loadStoreOp.getIndices();
244
245 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
246 Value assertCond;
247 for (auto i : llvm::seq<int64_t>(0, rank)) {
248 Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
249 Value inBounds =
250 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
251 assertCond =
252 i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
253 : inBounds;
254 }
255 cf::AssertOp::create(builder, loc, assertCond,
256 generateErrorMessage(op, "out-of-bounds access"));
257 }
258};
259
260struct SubViewOpInterface
261 : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
262 SubViewOp> {
263 void
264 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
265 function_ref<std::string(Operation *, StringRef)>
266 generateErrorMessage) const {
267 auto subView = cast<SubViewOp>(op);
268 MemRefType sourceType = subView.getSource().getType();
269
270 // For each dimension, assert that:
271 // For empty slices (size == 0) : 0 <= offset <= dim_size
272 // For non-empty slices (size > 0): 0 <= offset < dim_size
273 // 0 <= offset + (size - 1) * stride
274 // dim_size
275 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
276 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
277
278 auto metadataOp =
279 ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
280
281 for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
282 // Reset insertion point to before the operation for each dimension.
283 builder.setInsertionPoint(subView);
284
285 Value offset = getValueOrCreateConstantIndexOp(
286 builder, loc, subView.getMixedOffsets()[i]);
287 Value size = getValueOrCreateConstantIndexOp(builder, loc,
288 subView.getMixedSizes()[i]);
289 Value stride = getValueOrCreateConstantIndexOp(
290 builder, loc, subView.getMixedStrides()[i]);
291 Value dimSize = metadataOp.getSizes()[i];
292
293 // Verify that offset is in-bounds (conditional on slice size).
294 Value sizeIsZero = arith::CmpIOp::create(
295 builder, loc, arith::CmpIPredicate::eq, size, zero);
296 auto offsetCheckIf = scf::IfOp::create(
297 builder, loc, sizeIsZero,
298 [&](OpBuilder &b, Location loc) {
299 // For empty slices, offset can be at the boundary: 0 <= offset <=
300 // dimSize.
301 Value offsetGEZero = arith::CmpIOp::create(
302 b, loc, arith::CmpIPredicate::sge, offset, zero);
303 Value offsetLEDimSize = arith::CmpIOp::create(
304 b, loc, arith::CmpIPredicate::sle, offset, dimSize);
305 Value emptyOffsetValid =
306 arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
307 scf::YieldOp::create(b, loc, emptyOffsetValid);
308 },
309 [&](OpBuilder &b, Location loc) {
310 // For non-empty slices, offset must be a valid index: 0 <= offset
311 // dimSize.
312 Value offsetInBounds =
313 generateInBoundsCheck(b, loc, offset, zero, dimSize);
314 scf::YieldOp::create(b, loc, offsetInBounds);
315 });
316
317 Value offsetCondition = offsetCheckIf.getResult(0);
318 cf::AssertOp::create(builder, loc, offsetCondition,
319 generateErrorMessage(op, "offset " +
320 std::to_string(i) +
321 " is out-of-bounds"));
322
323 // Verify that the slice endpoint is in-bounds (only for non-empty
324 // slices).
325 Value sizeIsNonZero = arith::CmpIOp::create(
326 builder, loc, arith::CmpIPredicate::sgt, size, zero);
327 auto ifOp = scf::IfOp::create(
328 builder, loc, sizeIsNonZero,
329 [&](OpBuilder &b, Location loc) {
330 // Verify that slice does not run out-of-bounds.
331 Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
332 Value sizeMinusOneTimesStride =
333 arith::MulIOp::create(b, loc, sizeMinusOne, stride);
334 Value lastPos =
335 arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
336 Value lastPosInBounds =
337 generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
338 scf::YieldOp::create(b, loc, lastPosInBounds);
339 },
340 [&](OpBuilder &b, Location loc) {
341 Value trueVal =
342 arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
343 scf::YieldOp::create(b, loc, trueVal);
344 });
345
346 Value finalCondition = ifOp.getResult(0);
347 cf::AssertOp::create(
348 builder, loc, finalCondition,
349 generateErrorMessage(op,
350 "subview runs out-of-bounds along dimension " +
351 std::to_string(i)));
352 }
353 }
354};
355
356struct ExpandShapeOpInterface
357 : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
358 ExpandShapeOp> {
359 void
360 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
361 function_ref<std::string(Operation *, StringRef)>
362 generateErrorMessage) const {
363 auto expandShapeOp = cast<ExpandShapeOp>(op);
364
365 // Verify that the expanded dim sizes are a product of the collapsed dim
366 // size.
367 for (const auto &it :
368 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
369 Value srcDimSz =
370 DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
371 int64_t groupSz = 1;
372 bool foundDynamicDim = false;
373 for (int64_t resultDim : it.value()) {
374 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
375 // Keep this assert here in case the op is extended in the future.
376 assert(!foundDynamicDim &&
377 "more than one dynamic dim found in reassoc group");
378 (void)foundDynamicDim;
379 foundDynamicDim = true;
380 continue;
381 }
382 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
383 }
384 Value staticResultDimSz =
385 arith::ConstantIndexOp::create(builder, loc, groupSz);
386 // staticResultDimSz must divide srcDimSz evenly.
387 Value mod =
388 arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
389 Value isModZero = arith::CmpIOp::create(
390 builder, loc, arith::CmpIPredicate::eq, mod,
391 arith::ConstantIndexOp::create(builder, loc, 0));
392 cf::AssertOp::create(
393 builder, loc, isModZero,
394 generateErrorMessage(op, "static result dims in reassoc group do not "
395 "divide src dim evenly"));
396 }
397 }
398};
399} // namespace
400} // namespace memref
401} // namespace mlir
402
404 DialectRegistry &registry) {
405 registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
406 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
407 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
408 CastOp::attachInterface<CastOpInterface>(*ctx);
409 CopyOp::attachInterface<CopyOpInterface>(*ctx);
410 DimOp::attachInterface<DimOpInterface>(*ctx);
411 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
412 GenericAtomicRMWOp::attachInterface<
413 LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
414 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
415 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
416 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
417 // Note: There is nothing to verify for ReinterpretCastOp.
418
419 // Load additional dialects of which ops may get created.
420 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
421 cf::ControlFlowDialect>();
422 });
423}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext * getContext() const
Definition Builders.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152