30 loc, arith::CmpIPredicate::sge, value, lb);
32 loc, arith::CmpIPredicate::slt, value, ub);
34 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
38 struct AssumeAlignmentOpInterface
39 :
public RuntimeVerifiableOpInterface::ExternalModel<
40 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
43 auto assumeOp = cast<AssumeAlignmentOp>(op);
44 Value ptr = builder.
create<ExtractAlignedPointerAsIndexOp>(
45 loc, assumeOp.getMemref());
48 builder.
create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
50 loc, arith::CmpIPredicate::eq, rest,
51 builder.
create<arith::ConstantIndexOp>(loc, 0));
52 builder.
create<cf::AssertOp>(
54 RuntimeVerifiableOpInterface::generateErrorMessage(
55 op,
"memref is not aligned to " +
56 std::to_string(assumeOp.getAlignment())));
60 struct CastOpInterface
61 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
65 auto castOp = cast<CastOp>(op);
66 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
69 auto resultType = dyn_cast<MemRefType>(castOp.getType());
73 if (isa<UnrankedMemRefType>(srcType)) {
75 Value srcRank = builder.
create<RankOp>(loc, castOp.getSource());
77 builder.
create<arith::ConstantIndexOp>(loc, resultType.getRank());
79 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
80 builder.
create<cf::AssertOp>(
82 RuntimeVerifiableOpInterface::generateErrorMessage(op,
90 int64_t dynamicOffset = ShapedType::kDynamic;
92 ShapedType::kDynamic);
94 dynamicOffset, dynamicShape);
97 stridedLayout, resultType.getMemorySpace());
99 builder.
create<CastOp>(loc, dynStridesType, castOp.getSource());
100 auto metadataOp = builder.
create<ExtractStridedMetadataOp>(loc, helperCast);
105 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
106 if (!rankedSrcType.isDynamicDim(it.index()))
110 if (resultType.isDynamicDim(it.index()))
114 builder.
create<DimOp>(loc, castOp.getSource(), it.index());
116 builder.
create<arith::ConstantIndexOp>(loc, it.value());
118 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
119 builder.
create<cf::AssertOp>(
121 RuntimeVerifiableOpInterface::generateErrorMessage(
122 op,
"size mismatch of dim " + std::to_string(it.index())));
126 int64_t resultOffset;
128 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
132 if (resultOffset != ShapedType::kDynamic) {
134 Value srcOffset = metadataOp.getResult(1);
135 Value resultOffsetVal =
136 builder.
create<arith::ConstantIndexOp>(loc, resultOffset);
137 Value isSameOffset = builder.
create<arith::CmpIOp>(
138 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
139 builder.
create<cf::AssertOp>(
141 RuntimeVerifiableOpInterface::generateErrorMessage(
142 op,
"offset mismatch"));
148 if (it.value() == ShapedType::kDynamic)
152 metadataOp.getResult(2 + resultType.getRank() + it.index());
153 Value resultStrideVal =
154 builder.
create<arith::ConstantIndexOp>(loc, it.value());
155 Value isSameStride = builder.
create<arith::CmpIOp>(
156 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
157 builder.
create<cf::AssertOp>(
159 RuntimeVerifiableOpInterface::generateErrorMessage(
160 op,
"stride mismatch of dim " + std::to_string(it.index())));
165 struct CopyOpInterface
166 :
public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
170 auto copyOp = cast<CopyOp>(op);
173 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
174 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
177 if (!rankedSourceType || !rankedTargetType)
180 assert(sourceType.getRank() == targetType.getRank() &&
"rank mismatch");
181 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
184 if (!rankedSourceType.isDynamicDim(i) &&
185 !rankedTargetType.isDynamicDim(i))
187 auto getDimSize = [&](
Value memRef, MemRefType type,
188 int64_t dim) ->
Value {
189 return type.isDynamicDim(dim)
190 ? builder.
create<DimOp>(loc, memRef, dim).getResult()
192 .create<arith::ConstantIndexOp>(loc,
193 type.getDimSize(dim))
196 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
197 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
199 loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
200 builder.
create<cf::AssertOp>(
202 RuntimeVerifiableOpInterface::generateErrorMessage(
203 op,
"size of " + std::to_string(i) +
204 "-th source/target dim does not match"));
209 struct DimOpInterface
210 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
214 auto dimOp = cast<DimOp>(op);
215 Value rank = builder.
create<RankOp>(loc, dimOp.getSource());
216 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
217 builder.
create<cf::AssertOp>(
218 loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
219 RuntimeVerifiableOpInterface::generateErrorMessage(
220 op,
"index is out of bounds"));
226 template <
typename LoadStoreOp>
227 struct LoadStoreOpInterface
228 :
public RuntimeVerifiableOpInterface::ExternalModel<
229 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
232 auto loadStoreOp = cast<LoadStoreOp>(op);
234 auto memref = loadStoreOp.getMemref();
235 auto rank = memref.getType().getRank();
239 auto indices = loadStoreOp.getIndices();
241 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
243 for (
auto i : llvm::seq<int64_t>(0, rank)) {
246 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
248 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
251 builder.
create<cf::AssertOp>(
253 RuntimeVerifiableOpInterface::generateErrorMessage(
254 op,
"out-of-bounds access"));
278 return {lowerBound, upperBound};
285 auto runtimeMetadata = builder.
create<ExtractStridedMetadataOp>(loc, memref);
286 auto offset = runtimeMetadata.getConstifiedMixedOffset();
287 auto strides = runtimeMetadata.getConstifiedMixedStrides();
288 auto sizes = runtimeMetadata.getConstifiedMixedSizes();
289 return computeLinearBounds(builder, loc, offset, strides, sizes);
294 struct ReinterpretCastOpInterface
295 :
public RuntimeVerifiableOpInterface::ExternalModel<
296 ReinterpretCastOpInterface, ReinterpretCastOp> {
299 auto reinterpretCast = cast<ReinterpretCastOp>(op);
300 auto baseMemref = reinterpretCast.getSource();
302 cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
307 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
310 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
314 loc, arith::CmpIPredicate::sge, low, baseLow);
318 loc, arith::CmpIPredicate::sle, high, baseHigh);
320 auto assertCond = builder.
createOrFold<arith::AndIOp>(loc, geLow, leHigh);
322 builder.
create<cf::AssertOp>(
324 RuntimeVerifiableOpInterface::generateErrorMessage(
326 "result of reinterpret_cast is out-of-bounds of the base memref"));
330 struct SubViewOpInterface
331 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
335 auto subView = cast<SubViewOp>(op);
336 MemRefType sourceType = subView.getSource().getType();
341 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
342 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
344 builder.
create<ExtractStridedMetadataOp>(loc, subView.getSource());
345 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
347 builder, loc, subView.getMixedOffsets()[i]);
349 subView.getMixedSizes()[i]);
351 builder, loc, subView.getMixedStrides()[i]);
354 Value dimSize = metadataOp.getSizes()[i];
355 Value offsetInBounds =
356 generateInBoundsCheck(builder, loc, offset, zero, dimSize);
357 builder.
create<cf::AssertOp>(
359 RuntimeVerifiableOpInterface::generateErrorMessage(
360 op,
"offset " + std::to_string(i) +
" is out-of-bounds"));
363 Value sizeMinusOne = builder.
create<arith::SubIOp>(loc, size, one);
364 Value sizeMinusOneTimesStride =
365 builder.
create<arith::MulIOp>(loc, sizeMinusOne, stride);
367 builder.
create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
368 Value lastPosInBounds =
369 generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
370 builder.
create<cf::AssertOp>(
371 loc, lastPosInBounds,
372 RuntimeVerifiableOpInterface::generateErrorMessage(
373 op,
"subview runs out-of-bounds along dimension " +
379 struct ExpandShapeOpInterface
380 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
384 auto expandShapeOp = cast<ExpandShapeOp>(op);
388 for (
const auto &it :
391 builder.
create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
393 bool foundDynamicDim =
false;
394 for (int64_t resultDim : it.value()) {
395 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
397 assert(!foundDynamicDim &&
398 "more than one dynamic dim found in reassoc group");
399 (void)foundDynamicDim;
400 foundDynamicDim =
true;
403 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
405 Value staticResultDimSz =
406 builder.
create<arith::ConstantIndexOp>(loc, groupSz);
409 builder.
create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
411 loc, arith::CmpIPredicate::eq, mod,
412 builder.
create<arith::ConstantIndexOp>(loc, 0));
413 builder.
create<cf::AssertOp>(
415 RuntimeVerifiableOpInterface::generateErrorMessage(
416 op,
"static result dims in reassoc group do not "
417 "divide src dim evenly"));
428 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
429 CastOp::attachInterface<CastOpInterface>(*ctx);
430 CopyOp::attachInterface<CopyOpInterface>(*ctx);
431 DimOp::attachInterface<DimOpInterface>(*ctx);
432 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
433 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
434 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
435 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
436 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
439 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
440 cf::ControlFlowDialect>();
This class provides a shared interface for ranked and unranked memref types.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...