29 loc, arith::CmpIPredicate::sge, value, lb);
31 loc, arith::CmpIPredicate::slt, value, ub);
33 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
37 struct AssumeAlignmentOpInterface
38 :
public RuntimeVerifiableOpInterface::ExternalModel<
39 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
43 generateErrorMessage)
const {
44 auto assumeOp = cast<AssumeAlignmentOp>(op);
45 Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
46 assumeOp.getMemref());
47 Value rest = arith::RemUIOp::create(
51 arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
54 builder, loc, isAligned,
55 generateErrorMessage(op,
"memref is not aligned to " +
56 std::to_string(assumeOp.getAlignment())));
60 struct CastOpInterface
61 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
66 generateErrorMessage)
const {
67 auto castOp = cast<CastOp>(op);
68 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
71 auto resultType = dyn_cast<MemRefType>(castOp.getType());
75 if (isa<UnrankedMemRefType>(srcType)) {
77 Value srcRank = RankOp::create(builder, loc, castOp.getSource());
80 Value isSameRank = arith::CmpIOp::create(
81 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
82 cf::AssertOp::create(builder, loc, isSameRank,
83 generateErrorMessage(op,
"rank mismatch"));
90 int64_t dynamicOffset = ShapedType::kDynamic;
92 ShapedType::kDynamic);
94 dynamicOffset, dynamicShape);
97 stridedLayout, resultType.getMemorySpace());
99 CastOp::create(builder, loc, dynStridesType, castOp.getSource());
101 ExtractStridedMetadataOp::create(builder, loc, helperCast);
106 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
107 if (!rankedSrcType.isDynamicDim(it.index()))
111 if (resultType.isDynamicDim(it.index()))
115 DimOp::create(builder, loc, castOp.getSource(), it.index());
118 Value isSameSz = arith::CmpIOp::create(
119 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
120 cf::AssertOp::create(
121 builder, loc, isSameSz,
122 generateErrorMessage(op,
"size mismatch of dim " +
123 std::to_string(it.index())));
127 int64_t resultOffset;
129 if (
failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
133 if (resultOffset != ShapedType::kDynamic) {
135 Value srcOffset = metadataOp.getResult(1);
136 Value resultOffsetVal =
138 Value isSameOffset = arith::CmpIOp::create(
139 builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
140 cf::AssertOp::create(builder, loc, isSameOffset,
141 generateErrorMessage(op,
"offset mismatch"));
147 if (it.value() == ShapedType::kDynamic)
151 metadataOp.getResult(2 + resultType.getRank() + it.index());
152 Value resultStrideVal =
154 Value isSameStride = arith::CmpIOp::create(
155 builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
156 cf::AssertOp::create(
157 builder, loc, isSameStride,
158 generateErrorMessage(op,
"stride mismatch of dim " +
159 std::to_string(it.index())));
164 struct CopyOpInterface
165 :
public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
170 generateErrorMessage)
const {
171 auto copyOp = cast<CopyOp>(op);
174 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
175 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
178 if (!rankedSourceType || !rankedTargetType)
181 assert(sourceType.getRank() == targetType.getRank() &&
"rank mismatch");
182 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
185 if (!rankedSourceType.isDynamicDim(i) &&
186 !rankedTargetType.isDynamicDim(i))
188 auto getDimSize = [&](
Value memRef, MemRefType type,
189 int64_t dim) ->
Value {
190 return type.isDynamicDim(dim)
191 ? DimOp::create(builder, loc, memRef, dim).getResult()
193 type.getDimSize(dim))
196 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
197 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
198 Value sameDimSize = arith::CmpIOp::create(
199 builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
200 cf::AssertOp::create(
201 builder, loc, sameDimSize,
202 generateErrorMessage(op,
"size of " + std::to_string(i) +
203 "-th source/target dim does not match"));
208 struct DimOpInterface
209 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
214 generateErrorMessage)
const {
215 auto dimOp = cast<DimOp>(op);
216 Value rank = RankOp::create(builder, loc, dimOp.getSource());
218 cf::AssertOp::create(
220 generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
221 generateErrorMessage(op,
"index is out of bounds"));
227 template <
typename LoadStoreOp>
228 struct LoadStoreOpInterface
229 :
public RuntimeVerifiableOpInterface::ExternalModel<
230 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
234 generateErrorMessage)
const {
235 auto loadStoreOp = cast<LoadStoreOp>(op);
237 auto memref = loadStoreOp.getMemref();
238 auto rank = memref.
getType().getRank();
242 auto indices = loadStoreOp.getIndices();
246 for (
auto i : llvm::seq<int64_t>(0, rank)) {
249 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
251 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
254 cf::AssertOp::create(builder, loc, assertCond,
255 generateErrorMessage(op,
"out-of-bounds access"));
259 struct SubViewOpInterface
260 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
265 generateErrorMessage)
const {
266 auto subView = cast<SubViewOp>(op);
267 MemRefType sourceType = subView.getSource().getType();
275 ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
276 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
278 builder, loc, subView.getMixedOffsets()[i]);
280 subView.getMixedSizes()[i]);
282 builder, loc, subView.getMixedStrides()[i]);
285 Value dimSize = metadataOp.getSizes()[i];
286 Value offsetInBounds =
287 generateInBoundsCheck(builder, loc, offset, zero, dimSize);
288 cf::AssertOp::create(builder, loc, offsetInBounds,
289 generateErrorMessage(op,
"offset " +
291 " is out-of-bounds"));
294 Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
295 Value sizeMinusOneTimesStride =
296 arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
298 arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
299 Value lastPosInBounds =
300 generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
301 cf::AssertOp::create(
302 builder, loc, lastPosInBounds,
303 generateErrorMessage(op,
304 "subview runs out-of-bounds along dimension " +
310 struct ExpandShapeOpInterface
311 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
316 generateErrorMessage)
const {
317 auto expandShapeOp = cast<ExpandShapeOp>(op);
321 for (
const auto &it :
324 DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index());
326 bool foundDynamicDim =
false;
327 for (int64_t resultDim : it.value()) {
328 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
330 assert(!foundDynamicDim &&
331 "more than one dynamic dim found in reassoc group");
332 (void)foundDynamicDim;
333 foundDynamicDim =
true;
336 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
338 Value staticResultDimSz =
342 arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz);
343 Value isModZero = arith::CmpIOp::create(
344 builder, loc, arith::CmpIPredicate::eq, mod,
346 cf::AssertOp::create(
347 builder, loc, isModZero,
348 generateErrorMessage(op,
"static result dims in reassoc group do not "
349 "divide src dim evenly"));
360 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
361 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
362 CastOp::attachInterface<CastOpInterface>(*ctx);
363 CopyOp::attachInterface<CopyOpInterface>(*ctx);
364 DimOp::attachInterface<DimOpInterface>(*ctx);
365 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
366 GenericAtomicRMWOp::attachInterface<
367 LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
368 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
369 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
370 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
374 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
375 cf::ControlFlowDialect>();
This class provides a shared interface for ranked and unranked memref types.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...