26 struct CastOpInterface
27 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
31 auto castOp = cast<CastOp>(op);
32 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
35 auto resultType = dyn_cast<MemRefType>(castOp.getType());
39 if (isa<UnrankedMemRefType>(srcType)) {
41 Value srcRank = builder.
create<RankOp>(loc, castOp.getSource());
43 builder.
create<arith::ConstantIndexOp>(loc, resultType.getRank());
45 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
46 builder.
create<cf::AssertOp>(
48 RuntimeVerifiableOpInterface::generateErrorMessage(op,
56 int64_t dynamicOffset = ShapedType::kDynamic;
58 ShapedType::kDynamic);
60 dynamicOffset, dynamicShape);
63 stridedLayout, resultType.getMemorySpace());
65 builder.
create<CastOp>(loc, dynStridesType, castOp.getSource());
66 auto metadataOp = builder.
create<ExtractStridedMetadataOp>(loc, helperCast);
71 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
72 if (!rankedSrcType.isDynamicDim(it.index()))
76 if (resultType.isDynamicDim(it.index()))
80 builder.
create<DimOp>(loc, castOp.getSource(), it.index());
82 builder.
create<arith::ConstantIndexOp>(loc, it.value());
84 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
85 builder.
create<cf::AssertOp>(
87 RuntimeVerifiableOpInterface::generateErrorMessage(
88 op,
"size mismatch of dim " + std::to_string(it.index())));
98 if (resultOffset != ShapedType::kDynamic) {
100 Value srcOffset = metadataOp.getResult(1);
101 Value resultOffsetVal =
102 builder.
create<arith::ConstantIndexOp>(loc, resultOffset);
103 Value isSameOffset = builder.
create<arith::CmpIOp>(
104 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
105 builder.
create<cf::AssertOp>(
107 RuntimeVerifiableOpInterface::generateErrorMessage(
108 op,
"offset mismatch"));
114 if (it.value() == ShapedType::kDynamic)
118 metadataOp.getResult(2 + resultType.getRank() + it.index());
119 Value resultStrideVal =
120 builder.
create<arith::ConstantIndexOp>(loc, it.value());
121 Value isSameStride = builder.
create<arith::CmpIOp>(
122 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
123 builder.
create<cf::AssertOp>(
125 RuntimeVerifiableOpInterface::generateErrorMessage(
126 op,
"stride mismatch of dim " + std::to_string(it.index())));
133 template <
typename LoadStoreOp>
134 struct LoadStoreOpInterface
135 :
public RuntimeVerifiableOpInterface::ExternalModel<
136 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
139 auto loadStoreOp = cast<LoadStoreOp>(op);
141 auto memref = loadStoreOp.getMemref();
142 auto rank = memref.getType().getRank();
146 auto indices = loadStoreOp.getIndices();
148 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
150 for (
auto i : llvm::seq<int64_t>(0, rank)) {
151 auto index = indices[i];
153 auto dimOp = builder.
createOrFold<memref::DimOp>(loc, memref, i);
156 loc, arith::CmpIPredicate::sge, index, zero);
158 loc, arith::CmpIPredicate::slt, index, dimOp);
159 auto andOp = builder.
createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
162 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, andOp)
165 builder.
create<cf::AssertOp>(
167 RuntimeVerifiableOpInterface::generateErrorMessage(
168 op,
"out-of-bounds access"));
192 return {lowerBound, upperBound};
199 auto runtimeMetadata = builder.
create<ExtractStridedMetadataOp>(loc, memref);
200 auto offset = runtimeMetadata.getConstifiedMixedOffset();
201 auto strides = runtimeMetadata.getConstifiedMixedStrides();
202 auto sizes = runtimeMetadata.getConstifiedMixedSizes();
203 return computeLinearBounds(builder, loc, offset, strides, sizes);
208 struct ReinterpretCastOpInterface
209 :
public RuntimeVerifiableOpInterface::ExternalModel<
210 ReinterpretCastOpInterface, ReinterpretCastOp> {
213 auto reinterpretCast = cast<ReinterpretCastOp>(op);
214 auto baseMemref = reinterpretCast.getSource();
216 cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
221 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
224 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
228 loc, arith::CmpIPredicate::sge, low, baseLow);
232 loc, arith::CmpIPredicate::sle, high, baseHigh);
234 auto assertCond = builder.
createOrFold<arith::AndIOp>(loc, geLow, leHigh);
236 builder.
create<cf::AssertOp>(
238 RuntimeVerifiableOpInterface::generateErrorMessage(
240 "result of reinterpret_cast is out-of-bounds of the base memref"));
254 struct SubViewOpInterface
255 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
259 auto subView = cast<SubViewOp>(op);
260 auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
261 auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
266 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
269 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
273 loc, arith::CmpIPredicate::sge, low, baseLow);
277 loc, arith::CmpIPredicate::sle, high, baseHigh);
279 auto assertCond = builder.
createOrFold<arith::AndIOp>(loc, geLow, leHigh);
281 builder.
create<cf::AssertOp>(
283 RuntimeVerifiableOpInterface::generateErrorMessage(
284 op,
"subview is out-of-bounds of the base memref"));
288 struct ExpandShapeOpInterface
289 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
293 auto expandShapeOp = cast<ExpandShapeOp>(op);
297 for (
const auto &it :
300 builder.
create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
302 bool foundDynamicDim =
false;
303 for (int64_t resultDim : it.value()) {
304 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
306 assert(!foundDynamicDim &&
307 "more than one dynamic dim found in reassoc group");
308 (void)foundDynamicDim;
309 foundDynamicDim =
true;
312 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
314 Value staticResultDimSz =
315 builder.
create<arith::ConstantIndexOp>(loc, groupSz);
318 builder.
create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
320 loc, arith::CmpIPredicate::eq, mod,
321 builder.
create<arith::ConstantIndexOp>(loc, 0));
322 builder.
create<cf::AssertOp>(
324 RuntimeVerifiableOpInterface::generateErrorMessage(
325 op,
"static result dims in reassoc group do not "
326 "divide src dim evenly"));
337 CastOp::attachInterface<CastOpInterface>(*ctx);
338 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
339 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
340 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
341 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
342 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
345 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
346 cf::ControlFlowDialect>();
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...