27 loc, arith::CmpIPredicate::sge, value, lb);
29 loc, arith::CmpIPredicate::slt, value, ub);
31 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
35 struct CastOpInterface
36 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
40 auto castOp = cast<CastOp>(op);
41 auto srcType = cast<TensorType>(castOp.getSource().getType());
44 auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
48 if (isa<UnrankedTensorType>(srcType)) {
50 Value srcRank = RankOp::create(builder, loc, castOp.getSource());
53 Value isSameRank = arith::CmpIOp::create(
54 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
55 cf::AssertOp::create(builder, loc, isSameRank,
56 RuntimeVerifiableOpInterface::generateErrorMessage(
57 op,
"rank mismatch"));
63 if (
auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
64 if (!rankedSrcType.isDynamicDim(it.index()))
68 if (resultType.isDynamicDim(it.index()))
72 DimOp::create(builder, loc, castOp.getSource(), it.index());
75 Value isSameSz = arith::CmpIOp::create(
76 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
78 builder, loc, isSameSz,
79 RuntimeVerifiableOpInterface::generateErrorMessage(
80 op,
"size mismatch of dim " + std::to_string(it.index())));
86 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
90 auto dimOp = cast<DimOp>(op);
91 Value rank = RankOp::create(builder, loc, dimOp.getSource());
95 generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
96 RuntimeVerifiableOpInterface::generateErrorMessage(
97 op,
"index is out of bounds"));
103 template <
typename OpTy>
104 struct ExtractInsertOpInterface
105 :
public RuntimeVerifiableOpInterface::ExternalModel<
106 ExtractInsertOpInterface<OpTy>, OpTy> {
109 auto extractInsertOp = cast<OpTy>(op);
112 if constexpr (std::is_same_v<OpTy, ExtractOp>) {
113 tensor = extractInsertOp.getTensor();
114 }
else if constexpr (std::is_same_v<OpTy, InsertOp>) {
115 tensor = extractInsertOp.getDest();
117 llvm_unreachable(
"invalid op");
119 auto tensorType = cast<RankedTensorType>(tensor.
getType());
120 auto rank = tensorType.getRank();
126 auto indices = extractInsertOp.getIndices();
129 for (
auto i : llvm::seq<int64_t>(0, rank)) {
132 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
134 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
137 cf::AssertOp::create(builder, loc, assertCond,
138 RuntimeVerifiableOpInterface::generateErrorMessage(
139 op,
"out-of-bounds access"));
143 struct ExtractSliceOpInterface
144 :
public RuntimeVerifiableOpInterface::ExternalModel<
145 ExtractSliceOpInterface, ExtractSliceOp> {
148 auto extractSliceOp = cast<ExtractSliceOp>(op);
149 RankedTensorType sourceType = extractSliceOp.getSource().getType();
156 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
158 builder, loc, extractSliceOp.getMixedOffsets()[i]);
160 builder, loc, extractSliceOp.getMixedSizes()[i]);
162 builder, loc, extractSliceOp.getMixedStrides()[i]);
166 loc, extractSliceOp.getSource(), i);
167 Value offsetInBounds =
168 generateInBoundsCheck(builder, loc, offset, zero, dimSize);
169 cf::AssertOp::create(
170 builder, loc, offsetInBounds,
171 RuntimeVerifiableOpInterface::generateErrorMessage(
172 op,
"offset " + std::to_string(i) +
" is out-of-bounds"));
175 Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
176 Value sizeMinusOneTimesStride =
177 arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
179 arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
180 Value lastPosInBounds =
181 generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
182 cf::AssertOp::create(
183 builder, loc, lastPosInBounds,
184 RuntimeVerifiableOpInterface::generateErrorMessage(
185 op,
"extract_slice runs out-of-bounds along dimension " +
197 CastOp::attachInterface<CastOpInterface>(*ctx);
198 DimOp::attachInterface<DimOpInterface>(*ctx);
199 ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
200 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
201 InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
204 ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.