25Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
28 loc, arith::CmpIPredicate::sge, value, lb);
30 loc, arith::CmpIPredicate::slt, value, ub);
32 builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
37 :
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
40 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
42 generateErrorMessage)
const {
43 auto castOp = cast<CastOp>(op);
44 auto srcType = cast<TensorType>(castOp.getSource().getType());
47 auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
51 if (isa<UnrankedTensorType>(srcType)) {
53 Value srcRank = RankOp::create(builder, loc, castOp.getSource());
56 Value isSameRank = arith::CmpIOp::create(
57 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
58 cf::AssertOp::create(builder, loc, isSameRank,
59 generateErrorMessage(op,
"rank mismatch"));
63 for (
const auto &it : llvm::enumerate(resultType.getShape())) {
65 if (
auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
66 if (!rankedSrcType.isDynamicDim(it.index()))
70 if (resultType.isDynamicDim(it.index()))
74 DimOp::create(builder, loc, castOp.getSource(), it.index());
77 Value isSameSz = arith::CmpIOp::create(
78 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
80 builder, loc, isSameSz,
81 generateErrorMessage(op,
"size mismatch of dim " +
82 std::to_string(it.index())));
88 :
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
91 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
93 generateErrorMessage)
const {
94 auto dimOp = cast<DimOp>(op);
95 Value rank = RankOp::create(builder, loc, dimOp.getSource());
99 generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
100 generateErrorMessage(op,
"index is out of bounds"));
106template <
typename OpTy>
107struct ExtractInsertOpInterface
108 :
public RuntimeVerifiableOpInterface::ExternalModel<
109 ExtractInsertOpInterface<OpTy>, OpTy> {
111 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
113 generateErrorMessage)
const {
114 auto extractInsertOp = cast<OpTy>(op);
117 if constexpr (std::is_same_v<OpTy, ExtractOp>) {
118 tensor = extractInsertOp.getTensor();
119 }
else if constexpr (std::is_same_v<OpTy, InsertOp>) {
120 tensor = extractInsertOp.getDest();
122 llvm_unreachable(
"invalid op");
124 auto tensorType = cast<RankedTensorType>(tensor.
getType());
125 auto rank = tensorType.getRank();
131 auto indices = extractInsertOp.getIndices();
134 for (
auto i : llvm::seq<int64_t>(0, rank)) {
135 Value dimOp = builder.
createOrFold<tensor::DimOp>(loc, tensor, i);
137 generateInBoundsCheck(builder, loc,
indices[i], zero, dimOp);
139 i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
142 cf::AssertOp::create(builder, loc, assertCond,
143 generateErrorMessage(op,
"out-of-bounds access"));
147struct ExtractSliceOpInterface
148 :
public RuntimeVerifiableOpInterface::ExternalModel<
149 ExtractSliceOpInterface, ExtractSliceOp> {
151 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
153 generateErrorMessage)
const {
154 auto extractSliceOp = cast<ExtractSliceOp>(op);
155 RankedTensorType sourceType = extractSliceOp.getSource().getType();
165 for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
170 builder, loc, extractSliceOp.getMixedOffsets()[i]);
172 builder, loc, extractSliceOp.getMixedSizes()[i]);
174 builder, loc, extractSliceOp.getMixedStrides()[i]);
176 loc, extractSliceOp.getSource(), i);
179 Value sizeIsZero = arith::CmpIOp::create(
180 builder, loc, arith::CmpIPredicate::eq, size, zero);
181 auto offsetCheckIf = scf::IfOp::create(
182 builder, loc, sizeIsZero,
183 [&](OpBuilder &
b, Location loc) {
186 Value offsetGEZero = arith::CmpIOp::create(
187 b, loc, arith::CmpIPredicate::sge, offset, zero);
188 Value offsetLEDimSize = arith::CmpIOp::create(
189 b, loc, arith::CmpIPredicate::sle, offset, dimSize);
190 Value emptyOffsetValid =
191 arith::AndIOp::create(
b, loc, offsetGEZero, offsetLEDimSize);
192 scf::YieldOp::create(
b, loc, emptyOffsetValid);
194 [&](OpBuilder &
b, Location loc) {
197 Value offsetInBounds =
198 generateInBoundsCheck(
b, loc, offset, zero, dimSize);
199 scf::YieldOp::create(
b, loc, offsetInBounds);
202 Value offsetCondition = offsetCheckIf.getResult(0);
203 cf::AssertOp::create(builder, loc, offsetCondition,
204 generateErrorMessage(op,
"offset " +
206 " is out-of-bounds"));
210 Value sizeIsNonZero = arith::CmpIOp::create(
211 builder, loc, arith::CmpIPredicate::sgt, size, zero);
212 auto ifOp = scf::IfOp::create(
213 builder, loc, sizeIsNonZero,
214 [&](OpBuilder &
b, Location loc) {
216 Value sizeMinusOne = arith::SubIOp::create(
b, loc, size, one);
217 Value sizeMinusOneTimesStride =
218 arith::MulIOp::create(
b, loc, sizeMinusOne, stride);
220 arith::AddIOp::create(
b, loc, offset, sizeMinusOneTimesStride);
221 Value lastPosInBounds =
222 generateInBoundsCheck(
b, loc, lastPos, zero, dimSize);
223 scf::YieldOp::create(
b, loc, lastPosInBounds);
225 [&](OpBuilder &
b, Location loc) {
227 arith::ConstantOp::create(
b, loc,
b.getBoolAttr(
true));
228 scf::YieldOp::create(
b, loc, trueVal);
231 Value finalCondition = ifOp.getResult(0);
232 cf::AssertOp::create(
233 builder, loc, finalCondition,
234 generateErrorMessage(
235 op,
"extract_slice runs out-of-bounds along dimension " +
247 CastOp::attachInterface<CastOpInterface>(*ctx);
248 DimOp::attachInterface<DimOpInterface>(*ctx);
249 ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
250 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
251 InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
254 ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::function_ref< Fn > function_ref