29 loc, arith::CmpIPredicate::sge, value, lb);
31 loc, arith::CmpIPredicate::slt, value, ub);
33 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
37 struct AssumeAlignmentOpInterface
38 :
public RuntimeVerifiableOpInterface::ExternalModel<
39 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
42 auto assumeOp = cast<AssumeAlignmentOp>(op);
43 Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
44 assumeOp.getMemref());
45 Value rest = arith::RemUIOp::create(
49 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
51 cf::AssertOp::create(builder, loc, isAligned,
52 RuntimeVerifiableOpInterface::generateErrorMessage(
53 op,
"memref is not aligned to " +
54 std::to_string(assumeOp.getAlignment())));
58 struct CastOpInterface
59 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
63 auto castOp = cast<CastOp>(op);
64 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
67 auto resultType = dyn_cast<MemRefType>(castOp.getType());
71 if (isa<UnrankedMemRefType>(srcType)) {
73 Value srcRank = RankOp::create(builder, loc, castOp.getSource());
76 Value isSameRank = arith::CmpIOp::create(
77 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
78 cf::AssertOp::create(builder, loc, isSameRank,
79 RuntimeVerifiableOpInterface::generateErrorMessage(
80 op,
"rank mismatch"));
87 int64_t dynamicOffset = ShapedType::kDynamic;
89 ShapedType::kDynamic);
91 dynamicOffset, dynamicShape);
94 stridedLayout, resultType.getMemorySpace());
96 CastOp::create(builder, loc, dynStridesType, castOp.getSource());
98 ExtractStridedMetadataOp::create(builder, loc, helperCast);
103 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
104 if (!rankedSrcType.isDynamicDim(it.index()))
108 if (resultType.isDynamicDim(it.index()))
112 DimOp::create(builder, loc, castOp.getSource(), it.index());
115 Value isSameSz = arith::CmpIOp::create(
116 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
117 cf::AssertOp::create(
118 builder, loc, isSameSz,
119 RuntimeVerifiableOpInterface::generateErrorMessage(
120 op,
"size mismatch of dim " + std::to_string(it.index())));
124 int64_t resultOffset;
126 if (
failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
130 if (resultOffset != ShapedType::kDynamic) {
132 Value srcOffset = metadataOp.getResult(1);
133 Value resultOffsetVal =
135 Value isSameOffset = arith::CmpIOp::create(
136 builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
137 cf::AssertOp::create(builder, loc, isSameOffset,
138 RuntimeVerifiableOpInterface::generateErrorMessage(
139 op,
"offset mismatch"));
145 if (it.value() == ShapedType::kDynamic)
149 metadataOp.getResult(2 + resultType.getRank() + it.index());
150 Value resultStrideVal =
152 Value isSameStride = arith::CmpIOp::create(
153 builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
154 cf::AssertOp::create(
155 builder, loc, isSameStride,
156 RuntimeVerifiableOpInterface::generateErrorMessage(
157 op,
"stride mismatch of dim " + std::to_string(it.index())));
162 struct CopyOpInterface
163 :
public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
167 auto copyOp = cast<CopyOp>(op);
170 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
171 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
174 if (!rankedSourceType || !rankedTargetType)
177 assert(sourceType.getRank() == targetType.getRank() &&
"rank mismatch");
178 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
181 if (!rankedSourceType.isDynamicDim(i) &&
182 !rankedTargetType.isDynamicDim(i))
184 auto getDimSize = [&](
Value memRef, MemRefType type,
185 int64_t dim) ->
Value {
186 return type.isDynamicDim(dim)
187 ? DimOp::create(builder, loc, memRef, dim).getResult()
189 type.getDimSize(dim))
192 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
193 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
194 Value sameDimSize = arith::CmpIOp::create(
195 builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
196 cf::AssertOp::create(builder, loc, sameDimSize,
197 RuntimeVerifiableOpInterface::generateErrorMessage(
198 op,
"size of " + std::to_string(i) +
199 "-th source/target dim does not match"));
204 struct DimOpInterface
205 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
209 auto dimOp = cast<DimOp>(op);
210 Value rank = RankOp::create(builder, loc, dimOp.getSource());
212 cf::AssertOp::create(
214 generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
215 RuntimeVerifiableOpInterface::generateErrorMessage(
216 op,
"index is out of bounds"));
222 template <
typename LoadStoreOp>
223 struct LoadStoreOpInterface
224 :
public RuntimeVerifiableOpInterface::ExternalModel<
225 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
228 auto loadStoreOp = cast<LoadStoreOp>(op);
230 auto memref = loadStoreOp.getMemref();
231 auto rank = memref.getType().getRank();
235 auto indices = loadStoreOp.getIndices();
239 for (
auto i : llvm::seq<int64_t>(0, rank)) {
242 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
244 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
247 cf::AssertOp::create(builder, loc, assertCond,
248 RuntimeVerifiableOpInterface::generateErrorMessage(
249 op,
"out-of-bounds access"));
253 struct SubViewOpInterface
254 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
258 auto subView = cast<SubViewOp>(op);
259 MemRefType sourceType = subView.getSource().getType();
267 ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
268 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
270 builder, loc, subView.getMixedOffsets()[i]);
272 subView.getMixedSizes()[i]);
274 builder, loc, subView.getMixedStrides()[i]);
277 Value dimSize = metadataOp.getSizes()[i];
278 Value offsetInBounds =
279 generateInBoundsCheck(builder, loc, offset, zero, dimSize);
280 cf::AssertOp::create(
281 builder, loc, offsetInBounds,
282 RuntimeVerifiableOpInterface::generateErrorMessage(
283 op,
"offset " + std::to_string(i) +
" is out-of-bounds"));
286 Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
287 Value sizeMinusOneTimesStride =
288 arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
290 arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
291 Value lastPosInBounds =
292 generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
293 cf::AssertOp::create(
294 builder, loc, lastPosInBounds,
295 RuntimeVerifiableOpInterface::generateErrorMessage(
296 op,
"subview runs out-of-bounds along dimension " +
302 struct ExpandShapeOpInterface
303 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
307 auto expandShapeOp = cast<ExpandShapeOp>(op);
311 for (
const auto &it :
314 DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
316 bool foundDynamicDim =
false;
317 for (int64_t resultDim : it.value()) {
318 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
320 assert(!foundDynamicDim &&
321 "more than one dynamic dim found in reassoc group");
322 (void)foundDynamicDim;
323 foundDynamicDim =
true;
326 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
328 Value staticResultDimSz =
332 arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
333 Value isModZero = arith::CmpIOp::create(
334 builder, loc, arith::CmpIPredicate::eq, mod,
336 cf::AssertOp::create(builder, loc, isModZero,
337 RuntimeVerifiableOpInterface::generateErrorMessage(
338 op,
"static result dims in reassoc group do not "
339 "divide src dim evenly"));
350 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
351 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
352 CastOp::attachInterface<CastOpInterface>(*ctx);
353 CopyOp::attachInterface<CopyOpInterface>(*ctx);
354 DimOp::attachInterface<DimOpInterface>(*ctx);
355 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
356 GenericAtomicRMWOp::attachInterface<
357 LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
358 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
359 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
360 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
364 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
365 cf::ControlFlowDialect>();
This class provides a shared interface for ranked and unranked memref types.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.