25struct LdMatrixOpInterface final
26 : IndexedAccessOpInterface::ExternalModel<LdMatrixOpInterface, LdMatrixOp> {
28 return cast<LdMatrixOp>(op).getSrcMemref();
32 return cast<LdMatrixOp>(op).getIndices();
35 SmallVector<int64_t> getAccessedShape(Operation *op)
const {
36 VectorType vecTy = cast<LdMatrixOp>(op).getRes().
getType();
40 return SmallVector<int64_t>{vecTy.getNumElements()};
43 std::optional<SmallVector<Value>>
44 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
46 auto ldMatrixOp = cast<LdMatrixOp>(op);
48 ldMatrixOp.getSrcMemrefMutable().assign(newMemref);
49 ldMatrixOp.getIndicesMutable().assign(newIndices);
54 bool hasInboundsIndices(Operation *)
const {
return true; }
57struct DeviceAsyncCopyOpInterface final
58 : IndexedMemCopyOpInterface::ExternalModel<DeviceAsyncCopyOpInterface,
61 return cast<DeviceAsyncCopyOp>(op).getSrc();
65 return cast<DeviceAsyncCopyOp>(op).getSrcIndices();
69 return cast<DeviceAsyncCopyOp>(op).getDst();
73 return cast<DeviceAsyncCopyOp>(op).getDstIndices();
76 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
79 auto copyOp = cast<DeviceAsyncCopyOp>(op);
81 copyOp.getSrcMutable().assign(newSrc);
82 copyOp.getSrcIndicesMutable().assign(newSrcIndices);
83 copyOp.getDstMutable().assign(newDst);
84 copyOp.getDstIndicesMutable().assign(newDstIndices);
93 LdMatrixOp::attachInterface<LdMatrixOpInterface>(*ctx);
94 DeviceAsyncCopyOp::attachInterface<DeviceAsyncCopyOpInterface>(*ctx);
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
type_range getType() const
OperandRange operand_range
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void registerMemoryAccessOpInterfacesExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.