27Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
30 loc, arith::CmpIPredicate::sge, value, lb);
32 loc, arith::CmpIPredicate::slt, value, ub);
34 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
38struct AssumeAlignmentOpInterface
39 :
public RuntimeVerifiableOpInterface::ExternalModel<
40 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
42 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
44 generateErrorMessage)
const {
45 auto assumeOp = cast<AssumeAlignmentOp>(op);
46 Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
47 assumeOp.getMemref());
48 Value rest = arith::RemUIOp::create(
52 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
55 builder, loc, isAligned,
56 generateErrorMessage(op,
"memref is not aligned to " +
57 std::to_string(assumeOp.getAlignment())));
62 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
65 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
67 generateErrorMessage)
const {
68 auto castOp = cast<CastOp>(op);
69 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
72 auto resultType = dyn_cast<MemRefType>(castOp.getType());
76 if (isa<UnrankedMemRefType>(srcType)) {
78 Value srcRank = RankOp::create(builder, loc, castOp.getSource());
81 Value isSameRank = arith::CmpIOp::create(
82 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
83 cf::AssertOp::create(builder, loc, isSameRank,
84 generateErrorMessage(op,
"rank mismatch"));
91 int64_t dynamicOffset = ShapedType::kDynamic;
92 SmallVector<int64_t> dynamicShape(resultType.getRank(),
93 ShapedType::kDynamic);
94 auto stridedLayout = StridedLayoutAttr::get(builder.
getContext(),
95 dynamicOffset, dynamicShape);
97 MemRefType::get(dynamicShape, resultType.getElementType(),
98 stridedLayout, resultType.getMemorySpace());
100 CastOp::create(builder, loc, dynStridesType, castOp.getSource());
102 ExtractStridedMetadataOp::create(builder, loc, helperCast);
105 for (
const auto &it : llvm::enumerate(resultType.getShape())) {
107 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
108 if (!rankedSrcType.isDynamicDim(it.index()))
112 if (resultType.isDynamicDim(it.index()))
116 DimOp::create(builder, loc, castOp.getSource(), it.index());
119 Value isSameSz = arith::CmpIOp::create(
120 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
121 cf::AssertOp::create(
122 builder, loc, isSameSz,
123 generateErrorMessage(op,
"size mismatch of dim " +
124 std::to_string(it.index())));
128 int64_t resultOffset;
129 SmallVector<int64_t> resultStrides;
130 if (
failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
134 if (resultOffset != ShapedType::kDynamic) {
136 Value srcOffset = metadataOp.getResult(1);
137 Value resultOffsetVal =
139 Value isSameOffset = arith::CmpIOp::create(
140 builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
141 cf::AssertOp::create(builder, loc, isSameOffset,
142 generateErrorMessage(op,
"offset mismatch"));
146 for (
const auto &it : llvm::enumerate(resultStrides)) {
148 if (it.value() == ShapedType::kDynamic)
152 metadataOp.getResult(2 + resultType.getRank() + it.index());
153 Value resultStrideVal =
155 Value isSameStride = arith::CmpIOp::create(
156 builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
157 cf::AssertOp::create(
158 builder, loc, isSameStride,
159 generateErrorMessage(op,
"stride mismatch of dim " +
160 std::to_string(it.index())));
165struct CopyOpInterface
166 :
public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
169 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
171 generateErrorMessage)
const {
172 auto copyOp = cast<CopyOp>(op);
173 BaseMemRefType sourceType = copyOp.getSource().getType();
174 BaseMemRefType targetType = copyOp.getTarget().getType();
175 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
176 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
179 if (!rankedSourceType || !rankedTargetType)
182 assert(sourceType.getRank() == targetType.getRank() &&
"rank mismatch");
183 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
186 if (!rankedSourceType.isDynamicDim(i) &&
187 !rankedTargetType.isDynamicDim(i))
189 auto getDimSize = [&](Value memRef, MemRefType type,
190 int64_t dim) -> Value {
191 return type.isDynamicDim(dim)
192 ? DimOp::create(builder, loc, memRef, dim).getResult()
194 type.getDimSize(dim))
197 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
198 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
199 Value sameDimSize = arith::CmpIOp::create(
200 builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
201 cf::AssertOp::create(
202 builder, loc, sameDimSize,
203 generateErrorMessage(op,
"size of " + std::to_string(i) +
204 "-th source/target dim does not match"));
210 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
213 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
215 generateErrorMessage)
const {
216 auto dimOp = cast<DimOp>(op);
217 Value rank = RankOp::create(builder, loc, dimOp.getSource());
219 cf::AssertOp::create(
221 generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
222 generateErrorMessage(op,
"index is out of bounds"));
228template <
typename LoadStoreOp>
229struct LoadStoreOpInterface
230 :
public RuntimeVerifiableOpInterface::ExternalModel<
231 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
233 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
235 generateErrorMessage)
const {
236 auto loadStoreOp = cast<LoadStoreOp>(op);
238 auto memref = loadStoreOp.getMemref();
239 auto rank = memref.
getType().getRank();
243 auto indices = loadStoreOp.getIndices();
247 for (
auto i : llvm::seq<int64_t>(0, rank)) {
248 Value dimOp = builder.
createOrFold<memref::DimOp>(loc, memref, i);
250 generateInBoundsCheck(builder, loc,
indices[i], zero, dimOp);
252 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
255 cf::AssertOp::create(builder, loc, assertCond,
256 generateErrorMessage(op,
"out-of-bounds access"));
260struct SubViewOpInterface
261 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
264 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
266 generateErrorMessage)
const {
267 auto subView = cast<SubViewOp>(op);
268 MemRefType sourceType = subView.getSource().getType();
279 ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
281 for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
286 builder, loc, subView.getMixedOffsets()[i]);
288 subView.getMixedSizes()[i]);
290 builder, loc, subView.getMixedStrides()[i]);
291 Value dimSize = metadataOp.getSizes()[i];
294 Value sizeIsZero = arith::CmpIOp::create(
295 builder, loc, arith::CmpIPredicate::eq, size, zero);
296 auto offsetCheckIf = scf::IfOp::create(
297 builder, loc, sizeIsZero,
298 [&](OpBuilder &
b, Location loc) {
301 Value offsetGEZero = arith::CmpIOp::create(
302 b, loc, arith::CmpIPredicate::sge, offset, zero);
303 Value offsetLEDimSize = arith::CmpIOp::create(
304 b, loc, arith::CmpIPredicate::sle, offset, dimSize);
305 Value emptyOffsetValid =
306 arith::AndIOp::create(
b, loc, offsetGEZero, offsetLEDimSize);
307 scf::YieldOp::create(
b, loc, emptyOffsetValid);
309 [&](OpBuilder &
b, Location loc) {
312 Value offsetInBounds =
313 generateInBoundsCheck(
b, loc, offset, zero, dimSize);
314 scf::YieldOp::create(
b, loc, offsetInBounds);
317 Value offsetCondition = offsetCheckIf.getResult(0);
318 cf::AssertOp::create(builder, loc, offsetCondition,
319 generateErrorMessage(op,
"offset " +
321 " is out-of-bounds"));
325 Value sizeIsNonZero = arith::CmpIOp::create(
326 builder, loc, arith::CmpIPredicate::sgt, size, zero);
327 auto ifOp = scf::IfOp::create(
328 builder, loc, sizeIsNonZero,
329 [&](OpBuilder &
b, Location loc) {
331 Value sizeMinusOne = arith::SubIOp::create(
b, loc, size, one);
332 Value sizeMinusOneTimesStride =
333 arith::MulIOp::create(
b, loc, sizeMinusOne, stride);
335 arith::AddIOp::create(
b, loc, offset, sizeMinusOneTimesStride);
336 Value lastPosInBounds =
337 generateInBoundsCheck(
b, loc, lastPos, zero, dimSize);
338 scf::YieldOp::create(
b, loc, lastPosInBounds);
340 [&](OpBuilder &
b, Location loc) {
342 arith::ConstantOp::create(
b, loc,
b.getBoolAttr(
true));
343 scf::YieldOp::create(
b, loc, trueVal);
346 Value finalCondition = ifOp.getResult(0);
347 cf::AssertOp::create(
348 builder, loc, finalCondition,
349 generateErrorMessage(op,
350 "subview runs out-of-bounds along dimension " +
356struct ExpandShapeOpInterface
357 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
360 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
362 generateErrorMessage)
const {
363 auto expandShapeOp = cast<ExpandShapeOp>(op);
367 for (
const auto &it :
368 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
370 DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
372 bool foundDynamicDim =
false;
373 for (int64_t resultDim : it.value()) {
374 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
376 assert(!foundDynamicDim &&
377 "more than one dynamic dim found in reassoc group");
378 (void)foundDynamicDim;
379 foundDynamicDim =
true;
382 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
384 Value staticResultDimSz =
388 arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
389 Value isModZero = arith::CmpIOp::create(
390 builder, loc, arith::CmpIPredicate::eq, mod,
392 cf::AssertOp::create(
393 builder, loc, isModZero,
394 generateErrorMessage(op,
"static result dims in reassoc group do not "
395 "divide src dim evenly"));
406 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
407 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
408 CastOp::attachInterface<CastOpInterface>(*ctx);
409 CopyOp::attachInterface<CopyOpInterface>(*ctx);
410 DimOp::attachInterface<DimOpInterface>(*ctx);
411 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
412 GenericAtomicRMWOp::attachInterface<
413 LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
414 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
415 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
416 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
420 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
421 cf::ControlFlowDialect>();
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry ®istry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::function_ref< Fn > function_ref