MLIR  22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 namespace tensor {
23 namespace {
24 /// Generate a runtime check for lb <= value < ub.
25 Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
26  Value lb, Value ub) {
27  Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
28  loc, arith::CmpIPredicate::sge, value, lb);
29  Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
30  loc, arith::CmpIPredicate::slt, value, ub);
31  Value inBounds =
32  builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
33  return inBounds;
34 }
35 
36 struct CastOpInterface
37  : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
38  CastOp> {
39  void
40  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
41  function_ref<std::string(Operation *, StringRef)>
42  generateErrorMessage) const {
43  auto castOp = cast<CastOp>(op);
44  auto srcType = cast<TensorType>(castOp.getSource().getType());
45 
46  // Nothing to check if the result is an unranked tensor.
47  auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
48  if (!resultType)
49  return;
50 
51  if (isa<UnrankedTensorType>(srcType)) {
52  // Check rank.
53  Value srcRank = RankOp::create(builder, loc, castOp.getSource());
54  Value resultRank =
55  arith::ConstantIndexOp::create(builder, loc, resultType.getRank());
56  Value isSameRank = arith::CmpIOp::create(
57  builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
58  cf::AssertOp::create(builder, loc, isSameRank,
59  generateErrorMessage(op, "rank mismatch"));
60  }
61 
62  // Check dimension sizes.
63  for (const auto &it : llvm::enumerate(resultType.getShape())) {
64  // Static dim size -> static/dynamic dim size does not need verification.
65  if (auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
66  if (!rankedSrcType.isDynamicDim(it.index()))
67  continue;
68 
69  // Static/dynamic dim size -> dynamic dim size does not need verification.
70  if (resultType.isDynamicDim(it.index()))
71  continue;
72 
73  Value srcDimSz =
74  DimOp::create(builder, loc, castOp.getSource(), it.index());
75  Value resultDimSz =
76  arith::ConstantIndexOp::create(builder, loc, it.value());
77  Value isSameSz = arith::CmpIOp::create(
78  builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
79  cf::AssertOp::create(
80  builder, loc, isSameSz,
81  generateErrorMessage(op, "size mismatch of dim " +
82  std::to_string(it.index())));
83  }
84  }
85 };
86 
87 struct DimOpInterface
88  : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
89  DimOp> {
90  void
91  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
92  function_ref<std::string(Operation *, StringRef)>
93  generateErrorMessage) const {
94  auto dimOp = cast<DimOp>(op);
95  Value rank = RankOp::create(builder, loc, dimOp.getSource());
96  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
97  cf::AssertOp::create(
98  builder, loc,
99  generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
100  generateErrorMessage(op, "index is out of bounds"));
101  }
102 };
103 
104 /// Verifies that the indices on extract/insert ops are in-bounds of the
105 /// tensor's index space: 0 <= index#i < dim#i
106 template <typename OpTy>
107 struct ExtractInsertOpInterface
108  : public RuntimeVerifiableOpInterface::ExternalModel<
109  ExtractInsertOpInterface<OpTy>, OpTy> {
110  void
111  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
112  function_ref<std::string(Operation *, StringRef)>
113  generateErrorMessage) const {
114  auto extractInsertOp = cast<OpTy>(op);
115 
116  Value tensor;
117  if constexpr (std::is_same_v<OpTy, ExtractOp>) {
118  tensor = extractInsertOp.getTensor();
119  } else if constexpr (std::is_same_v<OpTy, InsertOp>) {
120  tensor = extractInsertOp.getDest();
121  } else {
122  llvm_unreachable("invalid op");
123  }
124  auto tensorType = cast<RankedTensorType>(tensor.getType());
125  auto rank = tensorType.getRank();
126  if (rank == 0) {
127  // Nothing to check for 0-d tensors.
128  return;
129  }
130 
131  auto indices = extractInsertOp.getIndices();
132  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
133  Value assertCond;
134  for (auto i : llvm::seq<int64_t>(0, rank)) {
135  Value dimOp = builder.createOrFold<tensor::DimOp>(loc, tensor, i);
136  Value inBounds =
137  generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
138  assertCond =
139  i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
140  : inBounds;
141  }
142  cf::AssertOp::create(builder, loc, assertCond,
143  generateErrorMessage(op, "out-of-bounds access"));
144  }
145 };
146 
147 struct ExtractSliceOpInterface
148  : public RuntimeVerifiableOpInterface::ExternalModel<
149  ExtractSliceOpInterface, ExtractSliceOp> {
150  void
151  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
152  function_ref<std::string(Operation *, StringRef)>
153  generateErrorMessage) const {
154  auto extractSliceOp = cast<ExtractSliceOp>(op);
155  RankedTensorType sourceType = extractSliceOp.getSource().getType();
156 
157  // For each dimension, assert that:
158  // 0 <= offset < dim_size
159  // 0 <= offset + (size - 1) * stride < dim_size
160  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
161  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
162 
163  for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
164  // Reset insertion point to before the operation for each dimension
165  builder.setInsertionPoint(extractSliceOp);
166 
168  builder, loc, extractSliceOp.getMixedOffsets()[i]);
170  builder, loc, extractSliceOp.getMixedSizes()[i]);
172  builder, loc, extractSliceOp.getMixedStrides()[i]);
173 
174  // Verify that offset is in-bounds.
175  Value dimSize = builder.createOrFold<tensor::DimOp>(
176  loc, extractSliceOp.getSource(), i);
177  Value offsetInBounds =
178  generateInBoundsCheck(builder, loc, offset, zero, dimSize);
179  cf::AssertOp::create(builder, loc, offsetInBounds,
180  generateErrorMessage(op, "offset " +
181  std::to_string(i) +
182  " is out-of-bounds"));
183 
184  // Only verify if size > 0
185  Value sizeIsNonZero = arith::CmpIOp::create(
186  builder, loc, arith::CmpIPredicate::sgt, size, zero);
187 
188  auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
189  sizeIsNonZero, /*withElseRegion=*/true);
190 
191  // Populate the "then" region (for size > 0).
192  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
193 
194  // Verify that slice does not run out-of-bounds.
195  Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
196  Value sizeMinusOneTimesStride =
197  arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
198  Value lastPos =
199  arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
200  Value lastPosInBounds =
201  generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
202  scf::YieldOp::create(builder, loc, lastPosInBounds);
203 
204  // Populate the "else" region (for size == 0).
205  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
206  Value trueVal =
207  arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
208  scf::YieldOp::create(builder, loc, trueVal);
209 
210  builder.setInsertionPointAfter(ifOp);
211  Value finalCondition = ifOp.getResult(0);
212 
213  cf::AssertOp::create(
214  builder, loc, finalCondition,
215  generateErrorMessage(
216  op, "extract_slice runs out-of-bounds along dimension " +
217  std::to_string(i)));
218  }
219  }
220 };
221 } // namespace
222 } // namespace tensor
223 } // namespace mlir
224 
226  DialectRegistry &registry) {
227  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
228  CastOp::attachInterface<CastOpInterface>(*ctx);
229  DimOp::attachInterface<DimOpInterface>(*ctx);
230  ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
231  ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
232  InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
233 
234  // Load additional dialects of which ops may get created.
235  ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
236  });
237 }
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:100
IntegerType getI1Type()
Definition: Builders.cpp:53
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111