26 llvm::raw_string_ostream stream(buffer);
34 stream <<
"ERROR: Runtime op verification failed\n";
35 op->
print(stream, flags);
36 stream <<
"\n^ " << msg;
37 stream <<
"\nLocation: ";
45 struct CastOpInterface
46 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
50 auto castOp = cast<CastOp>(op);
51 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
54 auto resultType = dyn_cast<MemRefType>(castOp.getType());
58 if (isa<UnrankedMemRefType>(srcType)) {
60 Value srcRank = builder.
create<RankOp>(loc, castOp.getSource());
62 builder.
create<arith::ConstantIndexOp>(loc, resultType.getRank());
64 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
65 builder.
create<cf::AssertOp>(loc, isSameRank,
73 int64_t dynamicOffset = ShapedType::kDynamic;
75 ShapedType::kDynamic);
77 dynamicOffset, dynamicShape);
80 stridedLayout, resultType.getMemorySpace());
82 builder.
create<CastOp>(loc, dynStridesType, castOp.getSource());
83 auto metadataOp = builder.
create<ExtractStridedMetadataOp>(loc, helperCast);
88 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
89 if (!rankedSrcType.isDynamicDim(it.index()))
93 if (resultType.isDynamicDim(it.index()))
97 builder.
create<DimOp>(loc, castOp.getSource(), it.index());
99 builder.
create<arith::ConstantIndexOp>(loc, it.value());
101 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
102 builder.
create<cf::AssertOp>(
105 std::to_string(it.index())));
109 int64_t resultOffset;
115 if (resultOffset != ShapedType::kDynamic) {
117 Value srcOffset = metadataOp.getResult(1);
118 Value resultOffsetVal =
119 builder.
create<arith::ConstantIndexOp>(loc, resultOffset);
120 Value isSameOffset = builder.
create<arith::CmpIOp>(
121 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
122 builder.
create<cf::AssertOp>(loc, isSameOffset,
129 if (it.value() == ShapedType::kDynamic)
133 metadataOp.getResult(2 + resultType.getRank() + it.index());
134 Value resultStrideVal =
135 builder.
create<arith::ConstantIndexOp>(loc, it.value());
136 Value isSameStride = builder.
create<arith::CmpIOp>(
137 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
138 builder.
create<cf::AssertOp>(
141 std::to_string(it.index())));
148 template <
typename LoadStoreOp>
149 struct LoadStoreOpInterface
150 :
public RuntimeVerifiableOpInterface::ExternalModel<
151 LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
154 auto loadStoreOp = cast<LoadStoreOp>(op);
156 auto memref = loadStoreOp.getMemref();
157 auto rank = memref.getType().getRank();
161 auto indices = loadStoreOp.getIndices();
163 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
165 for (
auto i : llvm::seq<int64_t>(0, rank)) {
166 auto index = indices[i];
168 auto dimOp = builder.
createOrFold<memref::DimOp>(loc, memref, i);
171 loc, arith::CmpIPredicate::sge, index, zero);
173 loc, arith::CmpIPredicate::slt, index, dimOp);
174 auto andOp = builder.
createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
177 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, andOp)
180 builder.
create<cf::AssertOp>(
205 return {lowerBound, upperBound};
212 auto runtimeMetadata = builder.
create<ExtractStridedMetadataOp>(loc, memref);
213 auto offset = runtimeMetadata.getConstifiedMixedOffset();
214 auto strides = runtimeMetadata.getConstifiedMixedStrides();
215 auto sizes = runtimeMetadata.getConstifiedMixedSizes();
216 return computeLinearBounds(builder, loc, offset, strides, sizes);
221 struct ReinterpretCastOpInterface
222 :
public RuntimeVerifiableOpInterface::ExternalModel<
223 ReinterpretCastOpInterface, ReinterpretCastOp> {
226 auto reinterpretCast = cast<ReinterpretCastOp>(op);
227 auto baseMemref = reinterpretCast.getSource();
229 cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
234 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
237 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
241 loc, arith::CmpIPredicate::sge, low, baseLow);
245 loc, arith::CmpIPredicate::sle, high, baseHigh);
247 auto assertCond = builder.
createOrFold<arith::AndIOp>(loc, geLow, leHigh);
249 builder.
create<cf::AssertOp>(
253 "result of reinterpret_cast is out-of-bounds of the base memref"));
267 struct SubViewOpInterface
268 :
public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
272 auto subView = cast<SubViewOp>(op);
273 auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
274 auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
279 auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
282 auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
286 loc, arith::CmpIPredicate::sge, low, baseLow);
290 loc, arith::CmpIPredicate::sle, high, baseHigh);
292 auto assertCond = builder.
createOrFold<arith::AndIOp>(loc, geLow, leHigh);
294 builder.
create<cf::AssertOp>(
297 "subview is out-of-bounds of the base memref"));
301 struct ExpandShapeOpInterface
302 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
306 auto expandShapeOp = cast<ExpandShapeOp>(op);
310 for (
const auto &it :
313 builder.
create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
315 bool foundDynamicDim =
false;
316 for (int64_t resultDim : it.value()) {
317 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
319 assert(!foundDynamicDim &&
320 "more than one dynamic dim found in reassoc group");
321 (void)foundDynamicDim;
322 foundDynamicDim =
true;
325 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
327 Value staticResultDimSz =
328 builder.
create<arith::ConstantIndexOp>(loc, groupSz);
331 builder.
create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
333 loc, arith::CmpIPredicate::eq,
mod,
334 builder.
create<arith::ConstantIndexOp>(loc, 0));
335 builder.
create<cf::AssertOp>(
338 "divide src dim evenly"));
349 CastOp::attachInterface<CastOpInterface>(*ctx);
350 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
351 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
352 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
353 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
354 SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
357 ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
358 cf::ControlFlowDialect>();
static std::string generateErrorMessage(Operation *op, const std::string &msg)
Generate an error message string for the given op and the specified error.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void print(raw_ostream &os) const
Print the location.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & elideLargeElementsAttrs(int64_t largeElementLimit=16)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
OpPrintingFlags & printGenericOpForm(bool enable=true)
Always print operations in the generic form.
OpPrintingFlags & useLocalScope()
Use local scope when printing the operation.
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry ®istry)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.