30 loc, arith::CmpIPredicate::sge, value, lb);
32 loc, arith::CmpIPredicate::slt, value, ub);
34 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
38 struct AssumeAlignmentOpInterface
39 :
public RuntimeVerifiableOpInterface::ExternalModel<
40 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
43 auto assumeOp = cast<AssumeAlignmentOp>(op);
44 Value ptr = builder.
create<ExtractAlignedPointerAsIndexOp>(
45 loc, assumeOp.getMemref());
48 builder.
create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
50 loc, arith::CmpIPredicate::eq, rest,
51 builder.
create<arith::ConstantIndexOp>(loc, 0));
52 builder.
create<cf::AssertOp>(
54 RuntimeVerifiableOpInterface::generateErrorMessage(
55 op,
"memref is not aligned to " +
56 std::to_string(assumeOp.getAlignment())));
60 struct CastOpInterface
61 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
65 auto castOp = cast<CastOp>(op);
66 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
69 auto resultType = dyn_cast<MemRefType>(castOp.getType());
73 if (isa<UnrankedMemRefType>(srcType)) {
75 Value srcRank = builder.
create<RankOp>(loc, castOp.getSource());
77 builder.
create<arith::ConstantIndexOp>(loc, resultType.getRank());
79 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
80 builder.
create<cf::AssertOp>(
82 RuntimeVerifiableOpInterface::generateErrorMessage(op,
90 int64_t dynamicOffset = ShapedType::kDynamic;
92 ShapedType::kDynamic);
94 dynamicOffset, dynamicShape);
97 stridedLayout, resultType.getMemorySpace());
99 builder.
create<CastOp>(loc, dynStridesType, castOp.getSource());
100 auto metadataOp = builder.
create<ExtractStridedMetadataOp>(loc, helperCast);
105 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
106 if (!rankedSrcType.isDynamicDim(it.index()))
110 if (resultType.isDynamicDim(it.index()))
114 builder.
create<DimOp>(loc, castOp.getSource(), it.index());
116 builder.
create<arith::ConstantIndexOp>(loc, it.value());
118 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
119 builder.
create<cf::AssertOp>(
121 RuntimeVerifiableOpInterface::generateErrorMessage(
122 op,
"size mismatch of dim " + std::to_string(it.index())));
126 int64_t resultOffset;
128 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
132 if (resultOffset != ShapedType::kDynamic) {
134 Value srcOffset = metadataOp.getResult(1);
135 Value resultOffsetVal =
136 builder.
create<arith::ConstantIndexOp>(loc, resultOffset);
137 Value isSameOffset = builder.
create<arith::CmpIOp>(
138 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
139 builder.
create<cf::AssertOp>(
141 RuntimeVerifiableOpInterface::generateErrorMessage(
142 op,
"offset mismatch"));
148 if (it.value() == ShapedType::kDynamic)
152 metadataOp.getResult(2 + resultType.getRank() + it.index());
153 Value resultStrideVal =
154 builder.
create<arith::ConstantIndexOp>(loc, it.value());
155 Value isSameStride = builder.
create<arith::CmpIOp>(
156 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
157 builder.
create<cf::AssertOp>(
159 RuntimeVerifiableOpInterface::generateErrorMessage(
160 op,
"stride mismatch of dim " + std::to_string(it.index())));
165 struct CopyOpInterface
166 :
public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
170 auto copyOp = cast<CopyOp>(op);
173 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
174 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
177 if (!rankedSourceType || !rankedTargetType)
180 assert(sourceType.getRank() == targetType.getRank() &&
"rank mismatch");
181 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
184 if (!rankedSourceType.isDynamicDim(i) &&
185 !rankedTargetType.isDynamicDim(i))
187 auto getDimSize = [&](
Value memRef, MemRefType type,
188 int64_t dim) ->
Value {
189 return type.isDynamicDim(dim)
190 ? builder.
create<DimOp>(loc, memRef, dim).getResult()
192 .create<arith::ConstantIndexOp>(loc,
193 type.getDimSize(dim))
196 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
197 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
199 loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
200 builder.
create<cf::AssertOp>(
202 RuntimeVerifiableOpInterface::generateErrorMessage(
203 op,
"size of " + std::to_string(i) +
204 "-th source/target dim does not match"));
209 struct DimOpInterface
210 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
214 auto dimOp = cast<DimOp>(op);
215 Value rank = builder.
create<RankOp>(loc, dimOp.getSource());
216 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
217 builder.
create<cf::AssertOp>(
218 loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
219 RuntimeVerifiableOpInterface::generateErrorMessage(
220 op,
"index is out of bounds"));
226 template <
typename LoadStoreOp>
227 struct LoadStoreOpInterface
228 :
public RuntimeVerifiableOpInterface::ExternalModel<
229 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
232 auto loadStoreOp = cast<LoadStoreOp>(op);
234 auto memref = loadStoreOp.getMemref();
235 auto rank = memref.getType().getRank();
239 auto indices = loadStoreOp.getIndices();
241 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
243 for (
auto i : llvm::seq<int64_t>(0, rank)) {
246 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
248 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
251 builder.
create<cf::AssertOp>(
253 RuntimeVerifiableOpInterface::generateErrorMessage(
254 op,
"out-of-bounds access"));
258 struct SubViewOpInterface
259 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
263 auto subView = cast<SubViewOp>(op);
264 MemRefType sourceType = subView.getSource().getType();
269 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
270 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
272 builder.
create<ExtractStridedMetadataOp>(loc, subView.getSource());
273 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
275 builder, loc, subView.getMixedOffsets()[i]);
277 subView.getMixedSizes()[i]);
279 builder, loc, subView.getMixedStrides()[i]);
282 Value dimSize = metadataOp.getSizes()[i];
283 Value offsetInBounds =
284 generateInBoundsCheck(builder, loc, offset, zero, dimSize);
285 builder.
create<cf::AssertOp>(
287 RuntimeVerifiableOpInterface::generateErrorMessage(
288 op,
"offset " + std::to_string(i) +
" is out-of-bounds"));
291 Value sizeMinusOne = builder.
create<arith::SubIOp>(loc, size, one);
292 Value sizeMinusOneTimesStride =
293 builder.
create<arith::MulIOp>(loc, sizeMinusOne, stride);
295 builder.
create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
296 Value lastPosInBounds =
297 generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
298 builder.
create<cf::AssertOp>(
299 loc, lastPosInBounds,
300 RuntimeVerifiableOpInterface::generateErrorMessage(
301 op,
"subview runs out-of-bounds along dimension " +
307 struct ExpandShapeOpInterface
308 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
312 auto expandShapeOp = cast<ExpandShapeOp>(op);
316 for (
const auto &it :
319 builder.
create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
321 bool foundDynamicDim =
false;
322 for (int64_t resultDim : it.value()) {
323 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
325 assert(!foundDynamicDim &&
326 "more than one dynamic dim found in reassoc group");
327 (void)foundDynamicDim;
328 foundDynamicDim =
true;
331 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
333 Value staticResultDimSz =
334 builder.
create<arith::ConstantIndexOp>(loc, groupSz);
337 builder.
create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
339 loc, arith::CmpIPredicate::eq, mod,
340 builder.
create<arith::ConstantIndexOp>(loc, 0));
341 builder.
create<cf::AssertOp>(
343 RuntimeVerifiableOpInterface::generateErrorMessage(
344 op,
"static result dims in reassoc group do not "
345 "divide src dim evenly"));
356 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
357 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
358 CastOp::attachInterface<CastOpInterface>(*ctx);
359 CopyOp::attachInterface<CopyOpInterface>(*ctx);
360 DimOp::attachInterface<DimOpInterface>(*ctx);
361 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
362 GenericAtomicRMWOp::attachInterface<
363 LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
364 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
365 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
366 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
370 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
371 cf::ControlFlowDialect>();
This class provides a shared interface for ranked and unranked memref types.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...