28       loc, arith::CmpIPredicate::sge, value, lb);
 
   30       loc, arith::CmpIPredicate::slt, value, ub);
 
   32       builder.
createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
 
   36 struct CastOpInterface
 
   37     : 
public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
 
   42                                   generateErrorMessage)
 const {
 
   43     auto castOp = cast<CastOp>(op);
 
   44     auto srcType = cast<TensorType>(castOp.getSource().getType());
 
   47     auto resultType = dyn_cast<RankedTensorType>(castOp.getType());
 
   51     if (isa<UnrankedTensorType>(srcType)) {
 
   53       Value srcRank = RankOp::create(builder, loc, castOp.getSource());
 
   56       Value isSameRank = arith::CmpIOp::create(
 
   57           builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
 
   58       cf::AssertOp::create(builder, loc, isSameRank,
 
   59                            generateErrorMessage(op, 
"rank mismatch"));
 
   65       if (
auto rankedSrcType = dyn_cast<RankedTensorType>(srcType))
 
   66         if (!rankedSrcType.isDynamicDim(it.index()))
 
   70       if (resultType.isDynamicDim(it.index()))
 
   74           DimOp::create(builder, loc, castOp.getSource(), it.index());
 
   77       Value isSameSz = arith::CmpIOp::create(
 
   78           builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
 
   80           builder, loc, isSameSz,
 
   81           generateErrorMessage(op, 
"size mismatch of dim " +
 
   82                                        std::to_string(it.index())));
 
   88     : 
public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
 
   93                                   generateErrorMessage)
 const {
 
   94     auto dimOp = cast<DimOp>(op);
 
   95     Value rank = RankOp::create(builder, loc, dimOp.getSource());
 
   99         generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
 
  100         generateErrorMessage(op, 
"index is out of bounds"));
 
  106 template <
typename OpTy>
 
  107 struct ExtractInsertOpInterface
 
  108     : 
public RuntimeVerifiableOpInterface::ExternalModel<
 
  109           ExtractInsertOpInterface<OpTy>, OpTy> {
 
  113                                   generateErrorMessage)
 const {
 
  114     auto extractInsertOp = cast<OpTy>(op);
 
  117     if constexpr (std::is_same_v<OpTy, ExtractOp>) {
 
  118       tensor = extractInsertOp.getTensor();
 
  119     } 
else if constexpr (std::is_same_v<OpTy, InsertOp>) {
 
  120       tensor = extractInsertOp.getDest();
 
  122       llvm_unreachable(
"invalid op");
 
  124     auto tensorType = cast<RankedTensorType>(tensor.
getType());
 
  125     auto rank = tensorType.getRank();
 
  131     auto indices = extractInsertOp.getIndices();
 
  134     for (
auto i : llvm::seq<int64_t>(0, rank)) {
 
  137           generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
 
  139           i > 0 ? builder.
createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
 
  142     cf::AssertOp::create(builder, loc, assertCond,
 
  143                          generateErrorMessage(op, 
"out-of-bounds access"));
 
  147 struct ExtractSliceOpInterface
 
  148     : 
public RuntimeVerifiableOpInterface::ExternalModel<
 
  149           ExtractSliceOpInterface, ExtractSliceOp> {
 
  153                                   generateErrorMessage)
 const {
 
  154     auto extractSliceOp = cast<ExtractSliceOp>(op);
 
  155     RankedTensorType sourceType = extractSliceOp.getSource().getType();
 
  163     for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
 
  168           builder, loc, extractSliceOp.getMixedOffsets()[i]);
 
  170           builder, loc, extractSliceOp.getMixedSizes()[i]);
 
  172           builder, loc, extractSliceOp.getMixedStrides()[i]);
 
  176           loc, extractSliceOp.getSource(), i);
 
  177       Value offsetInBounds =
 
  178           generateInBoundsCheck(builder, loc, offset, zero, dimSize);
 
  179       cf::AssertOp::create(builder, loc, offsetInBounds,
 
  180                            generateErrorMessage(op, 
"offset " +
 
  182                                                         " is out-of-bounds"));
 
  185       Value sizeIsNonZero = arith::CmpIOp::create(
 
  186           builder, loc, arith::CmpIPredicate::sgt, size, zero);
 
  188       auto ifOp = scf::IfOp::create(builder, loc, builder.
getI1Type(),
 
  189                                     sizeIsNonZero, 
true);
 
  195       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
 
  196       Value sizeMinusOneTimesStride =
 
  197           arith::MulIOp::create(builder, loc, sizeMinusOne, stride);
 
  199           arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
 
  200       Value lastPosInBounds =
 
  201           generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
 
  202       scf::YieldOp::create(builder, loc, lastPosInBounds);
 
  207           arith::ConstantOp::create(builder, loc, builder.
getBoolAttr(
true));
 
  208       scf::YieldOp::create(builder, loc, trueVal);
 
  211       Value finalCondition = ifOp.getResult(0);
 
  213       cf::AssertOp::create(
 
  214           builder, loc, finalCondition,
 
  215           generateErrorMessage(
 
  216               op, 
"extract_slice runs out-of-bounds along dimension " +
 
  228     CastOp::attachInterface<CastOpInterface>(*ctx);
 
  229     DimOp::attachInterface<DimOpInterface>(*ctx);
 
  230     ExtractOp::attachInterface<ExtractInsertOpInterface<ExtractOp>>(*ctx);
 
  231     ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
 
  232     InsertOp::attachInterface<ExtractInsertOpInterface<InsertOp>>(*ctx);
 
  235     ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
 
BoolAttr getBoolAttr(bool value)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.