22 llvm::raw_string_ostream stream(buffer);
24 stream <<
"ERROR: Runtime op verification failed\n";
25 op->
print(stream, flags);
26 stream <<
"\n^ " << msg;
27 stream <<
"\nLocation: ";
35 struct CastOpInterface
36 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
40 auto castOp = cast<CastOp>(op);
41 auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
44 auto resultType = dyn_cast<MemRefType>(castOp.getType());
48 if (isa<UnrankedMemRefType>(srcType)) {
50 Value srcRank = builder.
create<RankOp>(loc, castOp.getSource());
52 builder.
create<arith::ConstantIndexOp>(loc, resultType.getRank());
54 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
55 builder.
create<cf::AssertOp>(loc, isSameRank,
63 int64_t dynamicOffset = ShapedType::kDynamic;
65 ShapedType::kDynamic);
67 dynamicOffset, dynamicShape);
70 stridedLayout, resultType.getMemorySpace());
72 builder.
create<CastOp>(loc, dynStridesType, castOp.getSource());
73 auto metadataOp = builder.
create<ExtractStridedMetadataOp>(loc, helperCast);
78 if (
auto rankedSrcType = dyn_cast<MemRefType>(srcType))
79 if (!rankedSrcType.isDynamicDim(it.index()))
83 if (resultType.isDynamicDim(it.index()))
87 builder.
create<DimOp>(loc, castOp.getSource(), it.index());
89 builder.
create<arith::ConstantIndexOp>(loc, it.value());
91 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
92 builder.
create<cf::AssertOp>(
95 std::to_string(it.index())));
105 if (resultOffset != ShapedType::kDynamic) {
107 Value srcOffset = metadataOp.getResult(1);
108 Value resultOffsetVal =
109 builder.
create<arith::ConstantIndexOp>(loc, resultOffset);
110 Value isSameOffset = builder.
create<arith::CmpIOp>(
111 loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
112 builder.
create<cf::AssertOp>(loc, isSameOffset,
119 if (it.value() == ShapedType::kDynamic)
123 metadataOp.getResult(2 + resultType.getRank() + it.index());
124 Value resultStrideVal =
125 builder.
create<arith::ConstantIndexOp>(loc, it.value());
126 Value isSameStride = builder.
create<arith::CmpIOp>(
127 loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
128 builder.
create<cf::AssertOp>(
131 std::to_string(it.index())));
136 struct ExpandShapeOpInterface
137 :
public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
141 auto expandShapeOp = cast<ExpandShapeOp>(op);
145 for (
const auto &it :
148 builder.
create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
150 bool foundDynamicDim =
false;
151 for (int64_t resultDim : it.value()) {
152 if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
154 assert(!foundDynamicDim &&
155 "more than one dynamic dim found in reassoc group");
156 (void)foundDynamicDim;
157 foundDynamicDim =
true;
160 groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
162 Value staticResultDimSz =
163 builder.
create<arith::ConstantIndexOp>(loc, groupSz);
166 builder.
create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
168 loc, arith::CmpIPredicate::eq,
mod,
169 builder.
create<arith::ConstantIndexOp>(loc, 0));
170 builder.
create<cf::AssertOp>(
173 "divide src dim evenly"));
184 CastOp::attachInterface<CastOpInterface>(*ctx);
185 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
188 ctx->
loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
static std::string generateErrorMessage(Operation *op, const std::string &msg)
Generate an error message string for the given op and the specified error.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void print(raw_ostream &os) const
Print the location.
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
Operation is the basic unit of execution within MLIR.
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
Location getLoc()
The source location the operation was defined or derived from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.