MLIR 22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
18
19using namespace mlir;
20
21namespace mlir {
22namespace tensor {
23namespace {
24/// Generate a runtime check for lb <= value < ub.
25Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
26 Value lb, Value ub) {
27 Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
28 loc, arith::CmpIPredicate::sge, value, lb);
29 Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
30 loc, arith::CmpIPredicate::slt, value, ub);
31 Value inBounds =
32 builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
33 return inBounds;
34}
35
36struct CastOpInterface
37 : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
38 CastOp> {
39 void
40 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
41 function_ref<std::string(Operation *, StringRef)>
42 generateErrorMessage) const {
43 auto castOp = cast<CastOp>(op);
44 auto srcType = cast<TensorType>(castOp.getSource().getType());
45
46 // Nothing to check if the result is an unranked tensor.
47 auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
48 if (!resultType)
49 return;
50
51 if (isa<UnrankedTensorType>(srcType)) {
52 // Check rank.
53 Value srcRank = RankOp::create(builder, loc, castOp.getSource());
54 Value resultRank =
55 arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
56 Value isSameRank = arith::CmpIOp::create(
57 builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
58 cf::AssertOp::create(builder, loc, isSameRank,
59 generateErrorMessage(op, "rank mismatch"));
60 }
61
62 // Check dimension sizes.
63 for (const auto &it : llvm::enumerate(resultType.getShape())) {
64 // Static dim size -> static/dynamic dim size does not need verification.
65 if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
66 if (!rankedSrcType.isDynamicDim(it.index()))
67 continue;
68
69 // Static/dynamic dim size -> dynamic dim size does not need verification.
70 if (resultType.isDynamicDim(it.index()))
71 continue;
72
73 Value srcDimSz =
74 DimOp::create(builder, loc, castOp.getSource(), it.index());
75 Value resultDimSz =
76 arith::ConstantIndexOp::create(builder, loc, it.value());
77 Value isSameSz = arith::CmpIOp::create(
78 builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
79 cf::AssertOp::create(
80 builder, loc, isSameSz,
81 generateErrorMessage(op, "size mismatch of dim " +
82 std::to_string(it.index())));
83 }
84 }
85};
86
87struct DimOpInterface
88 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
89 DimOp> {
90 void
91 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
92 function_ref<std::string(Operation *, StringRef)>
93 generateErrorMessage) const {
94 auto dimOp = cast<DimOp>(op);
95 Value rank = RankOp::create(builder, loc, dimOp.getSource());
96 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
97 cf::AssertOp::create(
98 builder, loc,
99 generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
100 generateErrorMessage(op, "index is out of bounds"));
101 }
102};
103
104/// Verifies that the indices on extract/insert ops are in-bounds of the
105/// tensor's index space: 0 <= index#i < dim#i
106template <typename OpTy>
107struct ExtractInsertOpInterface
108 : public RuntimeVerifiableOpInterface::ExternalModel<
109 ExtractInsertOpInterface<OpTy>, OpTy> {
110 void
111 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
112 function_ref<std::string(Operation *, StringRef)>
113 generateErrorMessage) const {
114 auto extractInsertOp = cast<OpTy>(op);
115
116 Value tensor;
117 if constexpr (std::is_same_v<OpTy, ExtractOp>) {
118 tensor = extractInsertOp.getTensor();
119 } else if constexpr (std::is_same_v<OpTy, InsertOp>) {
120 tensor = extractInsertOp.getDest();
121 } else {
122 llvm_unreachable("invalid op");
123 }
124 auto tensorType = cast<RankedTensorType>(tensor.getType());
125 auto rank = tensorType.getRank();
126 if (rank == 0) {
127 // Nothing to check for 0-d tensors.
128 return;
129 }
130
131 auto indices = extractInsertOp.getIndices();
132 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
133 Value assertCond;
134 for (auto i : llvm::seq<int64_t>(0, rank)) {
135 Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i);
136 Value inBounds =
137 generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
138 assertCond =
139 i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
140 : inBounds;
141 }
142 cf::AssertOp::create(builder, loc, assertCond,
143 generateErrorMessage(op, "out-of-bounds access"));
144 }
145};
146
147struct ExtractSliceOpInterface
148 : public RuntimeVerifiableOpInterface::ExternalModel<
149 ExtractSliceOpInterface, ExtractSliceOp> {
150 void
151 generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
152 function_ref<std::string(Operation *, StringRef)>
153 generateErrorMessage) const {
154 auto extractSliceOp = cast<ExtractSliceOp>(op);
155 RankedTensorType sourceType = extractSliceOp.getSource().getType();
156
157 // For each dimension, assert that:
158 // For empty slices (size == 0) : 0 <= offset <= dim_size
159 // For non-empty slices (size > 0): 0 <= offset < dim_size
160 // 0 <= offset + (size - 1) * stride <
161 // dim_size
162 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
163 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
164
165 for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
166
167 builder.setInsertionPoint(extractSliceOp);
168
169 Value offset = getValueOrCreateConstantIndexOp(
170 builder, loc, extractSliceOp.getMixedOffsets()[i]);
172 builder, loc, extractSliceOp.getMixedSizes()[i]);
173 Value stride = getValueOrCreateConstantIndexOp(
174 builder, loc, extractSliceOp.getMixedStrides()[i]);
175 Value dimSize = builder.createOrFold<tensor::DimOp>(
176 loc, extractSliceOp.getSource(), i);
177
178 // Verify that offset is in-bounds (conditional on slice size).
179 Value sizeIsZero = arith::CmpIOp::create(
180 builder, loc, arith::CmpIPredicate::eq, size, zero);
181 auto offsetCheckIf = scf::IfOp::create(
182 builder, loc, sizeIsZero,
183 [&](OpBuilder &b, Location loc) {
184 // For empty slices, offset can be at the boundary: 0 <= offset <=
185 // dimSize.
186 Value offsetGEZero = arith::CmpIOp::create(
187 b, loc, arith::CmpIPredicate::sge, offset, zero);
188 Value offsetLEDimSize = arith::CmpIOp::create(
189 b, loc, arith::CmpIPredicate::sle, offset, dimSize);
190 Value emptyOffsetValid =
191 arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize);
192 scf::YieldOp::create(b, loc, emptyOffsetValid);
193 },
194 [&](OpBuilder &b, Location loc) {
195 // For non-empty slices, offset must be a valid index: 0 <= offset <
196 // dimSize.
197 Value offsetInBounds =
198 generateInBoundsCheck(b, loc, offset, zero, dimSize);
199 scf::YieldOp::create(b, loc, offsetInBounds);
200 });
201
202 Value offsetCondition = offsetCheckIf.getResult(0);
203 cf::AssertOp::create(builder, loc, offsetCondition,
204 generateErrorMessage(op, "offset " +
205 std::to_string(i) +
206 " is out-of-bounds"));
207
208 // Verify that the slice endpoint is in-bounds (only for non-empty
209 // slices).
210 Value sizeIsNonZero = arith::CmpIOp::create(
211 builder, loc, arith::CmpIPredicate::sgt, size, zero);
212 auto ifOp = scf::IfOp::create(
213 builder, loc, sizeIsNonZero,
214 [&](OpBuilder &b, Location loc) {
215 // Verify that slice does not run out-of-bounds.
216 Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one);
217 Value sizeMinusOneTimesStride =
218 arith::MulIOp::create(b, loc, sizeMinusOne, stride);
219 Value lastPos =
220 arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride);
221 Value lastPosInBounds =
222 generateInBoundsCheck(b, loc, lastPos, zero, dimSize);
223 scf::YieldOp::create(b, loc, lastPosInBounds);
224 },
225 [&](OpBuilder &b, Location loc) {
226 Value trueVal =
227 arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
228 scf::YieldOp::create(b, loc, trueVal);
229 });
230
231 Value finalCondition = ifOp.getResult(0);
232 cf::AssertOp::create(
233 builder, loc, finalCondition,
234 generateErrorMessage(
235 op, "extract_slice runs out-of-bounds along dimension " +
236 std::to_string(i)));
237 }
238 }
239};
240} // namespace
241} // namespace tensor
242} // namespace mlir
243
245 DialectRegistry &registry) {
246 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
247 CastOp::attachInterface<CastOpInterface>(*ctx);
248 DimOp::attachInterface<DimOpInterface>(*ctx);
249 ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
250 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
251 InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
252
253 // Load additional dialects of which ops may get created.
254 ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
255 });
256}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152