26 struct CastOpInterface
27 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
31 auto castOp = cast<CastOp>(op);
32 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
35 auto resultType = dyn_cast<MemRefType>(castOp.getType());
39 if (isa<UnrankedMemRefType>(srcType)) {
41 Value srcRank = builder.
create<RankOp>(loc, castOp.getSource());
43 builder.
create<arith::ConstantIndexOp>(loc, resultType.getRank());
45 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
46 builder.
create<cf::AssertOp>(
48 RuntimeVerifiableOpInterface::generateErrorMessage(op,
56 int64_t dynamicOffset = ShapedType::kDynamic;
58 ShapedType::kDynamic);
60 dynamicOffset, dynamicShape);
63 stridedLayout, resultType.getMemorySpace());
65 builder.
create<CastOp>(loc, dynStridesType, castOp.getSource());
66 auto metadataOp = builder.
create<ExtractStridedMetadataOp>(loc, helperCast);
71 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
72 if (!rankedSrcType.isDynamicDim(it.index()))
76 if (resultType.isDynamicDim(it.index()))
80 builder.
create<DimOp>(loc, castOp.getSource(), it.index());
82 builder.
create<arith::ConstantIndexOp>(loc, it.value());
84 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
85 builder.
create<cf::AssertOp>(
87 RuntimeVerifiableOpInterface::generateErrorMessage(
88 op,
"size mismatch of dim " + std::to_string(it.index())));
94 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
98 if (resultOffset != ShapedType::kDynamic) {
100 Value srcOffset = metadataOp.getResult(1);
101 Value resultOffsetVal =
102 builder.
create<arith::ConstantIndexOp>(loc, resultOffset);
103 Value isSameOffset = builder.
create<arith::CmpIOp>(
104 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
105 builder.
create<cf::AssertOp>(
107 RuntimeVerifiableOpInterface::generateErrorMessage(
108 op,
"offset mismatch"));
114 if (it.value() == ShapedType::kDynamic)
118 metadataOp.getResult(2 + resultType.getRank() + it.index());
119 Value resultStrideVal =
120 builder.
create<arith::ConstantIndexOp>(loc, it.value());
121 Value isSameStride = builder.
create<arith::CmpIOp>(
122 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
123 builder.
create<cf::AssertOp>(
125 RuntimeVerifiableOpInterface::generateErrorMessage(
126 op,
"stride mismatch of dim " + std::to_string(it.index())));
131 struct CopyOpInterface
132 :
public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
136 auto copyOp = cast<CopyOp>(op);
139 auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
140 auto rankedTargetType = dyn_cast<MemRefType>(targetType);
143 if (!rankedSourceType || !rankedTargetType)
146 assert(sourceType.getRank() == targetType.getRank() &&
"rank mismatch");
147 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
150 if (!rankedSourceType.isDynamicDim(i) &&
151 !rankedTargetType.isDynamicDim(i))
153 auto getDimSize = [&](
Value memRef, MemRefType type,
154 int64_t dim) ->
Value {
155 return type.isDynamicDim(dim)
156 ? builder.
create<DimOp>(loc, memRef, dim).getResult()
158 .create<arith::ConstantIndexOp>(loc,
159 type.getDimSize(dim))
162 Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i);
163 Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
165 loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
166 builder.
create<cf::AssertOp>(
168 RuntimeVerifiableOpInterface::generateErrorMessage(
169 op,
"size of " + std::to_string(i) +
170 "-th source/target dim does not match"));
177 template <
typename LoadStoreOp>
178 struct LoadStoreOpInterface
179 :
public RuntimeVerifiableOpInterface::ExternalModel<
180 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
183 auto loadStoreOp = cast<LoadStoreOp>(op);
185 auto memref = loadStoreOp.getMemref();
186 auto rank = memref.getType().getRank();
190 auto indices = loadStoreOp.getIndices();
192 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
194 for (
auto i : llvm::seq<int64_t>(0, rank)) {
195 auto index = indices[i];
197 auto dimOp = builder.
createOrFold<memref::DimOp>(loc, memref, i);
200 loc, arith::CmpIPredicate::sge, index, zero);
202 loc, arith::CmpIPredicate::slt, index, dimOp);
203 auto andOp = builder.
createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
206 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, andOp)
209 builder.
create<cf::AssertOp>(
211 RuntimeVerifiableOpInterface::generateErrorMessage(
212 op,
"out-of-bounds access"));
236 return {lowerBound, upperBound};
243 auto runtimeMetadata = builder.
create<ExtractStridedMetadataOp>(loc, memref);
244 auto offset = runtimeMetadata.getConstifiedMixedOffset();
245 auto strides = runtimeMetadata.getConstifiedMixedStrides();
246 auto sizes = runtimeMetadata.getConstifiedMixedSizes();
247 return computeLinearBounds(builder, loc, offset, strides, sizes);
252 struct ReinterpretCastOpInterface
253 :
public RuntimeVerifiableOpInterface::ExternalModel<
254 ReinterpretCastOpInterface, ReinterpretCastOp> {
257 auto reinterpretCast = cast<ReinterpretCastOp>(op);
258 auto baseMemref = reinterpretCast.getSource();
260 cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
265 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
268 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
272 loc, arith::CmpIPredicate::sge, low, baseLow);
276 loc, arith::CmpIPredicate::sle, high, baseHigh);
278 auto assertCond = builder.
createOrFold<arith::AndIOp>(loc, geLow, leHigh);
280 builder.
create<cf::AssertOp>(
282 RuntimeVerifiableOpInterface::generateErrorMessage(
284 "result of reinterpret_cast is out-of-bounds of the base memref"));
298 struct SubViewOpInterface
299 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
303 auto subView = cast<SubViewOp>(op);
304 auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
305 auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
310 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
313 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
317 loc, arith::CmpIPredicate::sge, low, baseLow);
321 loc, arith::CmpIPredicate::sle, high, baseHigh);
323 auto assertCond = builder.
createOrFold<arith::AndIOp>(loc, geLow, leHigh);
325 builder.
create<cf::AssertOp>(
327 RuntimeVerifiableOpInterface::generateErrorMessage(
328 op,
"subview is out-of-bounds of the base memref"));
332 struct ExpandShapeOpInterface
333 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
337 auto expandShapeOp = cast<ExpandShapeOp>(op);
341 for (
const auto &it :
344 builder.
create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
346 bool foundDynamicDim =
false;
347 for (int64_t resultDim : it.value()) {
348 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
350 assert(!foundDynamicDim &&
351 "more than one dynamic dim found in reassoc group");
352 (void)foundDynamicDim;
353 foundDynamicDim =
true;
356 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
358 Value staticResultDimSz =
359 builder.
create<arith::ConstantIndexOp>(loc, groupSz);
362 builder.
create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
364 loc, arith::CmpIPredicate::eq, mod,
365 builder.
create<arith::ConstantIndexOp>(loc, 0));
366 builder.
create<cf::AssertOp>(
368 RuntimeVerifiableOpInterface::generateErrorMessage(
369 op,
"static result dims in reassoc group do not "
370 "divide src dim evenly"));
381 CastOp::attachInterface<CastOpInterface>(*ctx);
382 CopyOp::attachInterface<CopyOpInterface>(*ctx);
383 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
384 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
385 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
386 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
387 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
390 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
391 cf::ControlFlowDialect>();
This class provides a shared interface for ranked and unranked memref types.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...