28 loc, arith::CmpIPredicate::sge, value, lb);
30 loc, arith::CmpIPredicate::slt, value, ub);
32 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
36 struct CastOpInterface
37 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
41 auto castOp = cast<CastOp>(op);
42 auto srcType = cast<TensorType>(castOp.getSource().getType());
45 auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
49 if (isa<UnrankedTensorType>(srcType)) {
51 Value srcRank = builder.
create<RankOp>(loc, castOp.getSource());
53 builder.
create<arith::ConstantIndexOp>(loc, resultType.getRank());
55 loc, arith::CmpIPredicate::eq, srcRank, resultRank);
56 builder.
create<cf::AssertOp>(
58 RuntimeVerifiableOpInterface::generateErrorMessage(op,
65 if (
auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
66 if (!rankedSrcType.isDynamicDim(it.index()))
70 if (resultType.isDynamicDim(it.index()))
74 builder.
create<DimOp>(loc, castOp.getSource(), it.index());
76 builder.
create<arith::ConstantIndexOp>(loc, it.value());
78 loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
79 builder.
create<cf::AssertOp>(
81 RuntimeVerifiableOpInterface::generateErrorMessage(
82 op,
"size mismatch of dim " + std::to_string(it.index())));
88 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
92 auto dimOp = cast<DimOp>(op);
93 Value rank = builder.
create<RankOp>(loc, dimOp.getSource());
94 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
95 builder.
create<cf::AssertOp>(
96 loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
97 RuntimeVerifiableOpInterface::generateErrorMessage(
98 op,
"index is out of bounds"));
104 template <
typename OpTy>
105 struct ExtractInsertOpInterface
106 :
public RuntimeVerifiableOpInterface::ExternalModel<
107 ExtractInsertOpInterface<OpTy>, OpTy> {
110 auto extractInsertOp = cast<OpTy>(op);
113 if constexpr (std::is_same_v<OpTy, ExtractOp>) {
114 tensor = extractInsertOp.getTensor();
115 }
else if constexpr (std::is_same_v<OpTy, InsertOp>) {
116 tensor = extractInsertOp.getDest();
118 llvm_unreachable(
"invalid op");
120 auto tensorType = cast<RankedTensorType>(tensor.
getType());
121 auto rank = tensorType.getRank();
127 auto indices = extractInsertOp.getIndices();
128 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
130 for (
auto i : llvm::seq<int64_t>(0, rank)) {
133 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
135 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
138 builder.
create<cf::AssertOp>(
140 RuntimeVerifiableOpInterface::generateErrorMessage(
141 op,
"out-of-bounds access"));
145 struct ExtractSliceOpInterface
146 :
public RuntimeVerifiableOpInterface::ExternalModel<
147 ExtractSliceOpInterface, ExtractSliceOp> {
150 auto extractSliceOp = cast<ExtractSliceOp>(op);
151 RankedTensorType sourceType = extractSliceOp.getSource().getType();
156 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
157 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
158 for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
160 builder, loc, extractSliceOp.getMixedOffsets()[i]);
162 builder, loc, extractSliceOp.getMixedSizes()[i]);
164 builder, loc, extractSliceOp.getMixedStrides()[i]);
168 loc, extractSliceOp.getSource(), i);
169 Value offsetInBounds =
170 generateInBoundsCheck(builder, loc, offset, zero, dimSize);
171 builder.
create<cf::AssertOp>(
173 RuntimeVerifiableOpInterface::generateErrorMessage(
174 op,
"offset " + std::to_string(i) +
" is out-of-bounds"));
177 Value sizeMinusOne = builder.
create<arith::SubIOp>(loc, size, one);
178 Value sizeMinusOneTimesStride =
179 builder.
create<arith::MulIOp>(loc, sizeMinusOne, stride);
181 builder.
create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
182 Value lastPosInBounds =
183 generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
184 builder.
create<cf::AssertOp>(
185 loc, lastPosInBounds,
186 RuntimeVerifiableOpInterface::generateErrorMessage(
187 op,
"extract_slice runs out-of-bounds along dimension " +
199 CastOp::attachInterface<CastOpInterface>(*ctx);
200 DimOp::attachInterface<DimOpInterface>(*ctx);
201 ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
202 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
203 InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
206 ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.