38using ConvertIndexShl =
40using ConvertIndexShrS =
42using ConvertIndexShrU =
/// Lowers `BoolConstantOp` by replacing it with a `spirv::ConstantOp`.
/// The new constant reuses the op's own result type (`op.getType()`); the
/// attribute argument of the replacement is on a line not shown here.
60struct ConvertIndexConstantBoolOpPattern final
61 : OpConversionPattern<BoolConstantOp> {
65 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter)
const override {
// The result type is taken directly from the op rather than from the
// SPIR-V type converter.
67 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
/// Lowers the index-dialect `ConstantOp` to a `spirv::ConstantOp` whose type
/// is the SPIR-V index type chosen by the `SPIRVTypeConverter`.
/// The constant's value is truncated to the converter's index bitwidth.
79struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
83 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter)
const override {
85 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
86 Type indexType = typeConverter->getIndexType();
// Narrow the stored APInt to the SPIR-V index width so the attribute's
// bitwidth matches `indexType`; high bits beyond that width are dropped.
// NOTE(review): APInt::trunc requires the target width to be <= the source
// width; this assumes the op's value is at least as wide as the index type.
88 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
89 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
90 op, indexType, IntegerAttr::get(indexType, value));
/// Lowers `CeilDivSOp` (signed ceiling division of `n` by `m`) to SPIR-V
/// integer ops, choosing between two formulas with a final `spirv.Select`:
///   x      = m > 0 ? -1 : 1
///   posRes = (n + x) / m + 1     // used when n != 0 and n, m have the same sign
///   negRes = -(-n / m)           // used otherwise (including n == 0)
/// All divisions are signed (`spirv.SDiv`, truncating toward zero).
102struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
106 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
107 ConversionPatternRewriter &rewriter)
const override {
108 Location loc = op.getLoc();
109 Value n = adaptor.getLhs();
111 Value m = adaptor.getRhs();
// Constants 0, 1 and -1 of n's type.
// NOTE(review): `n_type` is not declared in the visible lines; presumably it
// is `n.getType()` — confirm.
114 Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
115 IntegerAttr::get(n_type, 0));
116 Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
117 IntegerAttr::get(n_type, 1));
118 Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
119 IntegerAttr::get(n_type, -1));
// x = m > 0 ? -1 : 1
122 Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
123 Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);
// posRes = (n + x) / m + 1
126 Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
127 Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
128 Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);
// negRes = 0 - ((0 - n) / m)
131 Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
132 Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
133 Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);
// cmp = ((n > 0) == (m > 0)) && n != 0, i.e. a positive exact quotient, where
// truncating division rounds down and needs the `+ 1` correction.
137 Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
138 Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
139 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
140 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
// Select operands are (condition, value-if-true, value-if-false).
141 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
/// Lowers `CeilDivUOp` (unsigned ceiling division of `n` by `m`) as
///   n == 0 ? 0 : (n - 1) / m + 1
/// using an unsigned division (`spirv.UDiv`) and a final `spirv.Select`.
152struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
156 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter)
const override {
158 Location loc = op.getLoc();
159 Value n = adaptor.getLhs();
161 Value m = adaptor.getRhs();
// Constants 0 and 1 of n's type.
// NOTE(review): `n_type` is not declared in the visible lines; presumably it
// is `n.getType()` — confirm.
164 Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
165 IntegerAttr::get(n_type, 0));
166 Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
167 IntegerAttr::get(n_type, 1));
// (n - 1) / m + 1, computed unconditionally.
170 Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
171 Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
172 Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);
// For n == 0 the expression above would use the wrapped value n - 1, so that
// case is overridden with 0.
175 Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
176 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
188struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
192 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter)
const override {
194 Location loc = op.getLoc();
195 Value n = adaptor.getLhs();
197 Value m = adaptor.getRhs();
200 Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
201 IntegerAttr::get(n_type, 0));
202 Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
203 IntegerAttr::get(n_type, 1));
204 Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
205 IntegerAttr::get(n_type, -1));
208 Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
209 Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);
212 Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
213 Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
214 Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);
217 Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);
221 Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
223 spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
224 Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
226 Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
227 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
/// Shared lowering for the index cast ops, instantiated below as
/// `ConvertIndexCastS` (`CastSOp` -> `spirv::SConvertOp`) and
/// `ConvertIndexCastU` (`CastUOp` -> `spirv::UConvertOp`).
/// If the source and destination types end up identical, the converted input
/// is forwarded unchanged; otherwise a `ConvertOp` to `dstType` is created.
240template <
typename CastOp,
typename ConvertOp>
241struct ConvertIndexCast final : OpConversionPattern<CastOp> {
242 using OpConversionPattern<CastOp>::OpConversionPattern;
245 matchAndRewrite(CastOp op,
typename CastOp::Adaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
247 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
248 Type indexType = typeConverter->getIndexType();
250 Type srcType = adaptor.getInput().getType();
251 Type dstType = op.getType();
// NOTE(review): the bodies of the two IndexType branches are not visible;
// presumably they substitute `indexType` for an index-typed side — confirm.
252 if (isa<IndexType>(srcType)) {
255 if (isa<IndexType>(dstType)) {
// Same type on both sides: no conversion op is needed.
259 if (srcType == dstType) {
260 rewriter.replaceOp(op, adaptor.getInput());
// Types differ: emit the width conversion selected by the template argument.
262 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263 adaptor.getOperands());
// Signed casts use SConvert; unsigned casts use UConvert.
269using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
/// Replaces the index `CmpOp` with the SPIR-V comparison `ICmpOp`, applied to
/// the converted lhs and rhs operands. Used by `ConvertIndexCmpPattern` to map
/// each predicate to its SPIR-V op.
/// NOTE(review): no result type is passed to the builder; presumably `ICmpOp`
/// infers its boolean result type — confirm.
277template <
typename ICmpOp>
278static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter) {
280 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
/// Lowers the index `CmpOp` by dispatching on its predicate to the matching
/// SPIR-V integer comparison via `rewriteCmpOp`.
284struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
288 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter)
const override {
// EQ/NE map to IEqual/INotEqual; S* predicates use the signed SPIR-V
// comparisons and U* predicates the unsigned ones.
291 switch (op.getPred()) {
292 case IndexCmpPredicate::EQ:
293 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294 case IndexCmpPredicate::NE:
295 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296 case IndexCmpPredicate::SGE:
297 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::SGT:
299 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SLE:
301 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SLT:
303 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::UGE:
305 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::UGT:
307 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::ULE:
309 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::ULT:
311 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
// Every case above returns; control only reaches here if the predicate
// matched none of them.
313 llvm_unreachable(
"Unknown predicate in ConvertIndexCmpPattern");
/// Lowers `SizeOfOp` to a `spirv::ConstantOp` of the SPIR-V index type whose
/// value is the converter's `getIndexTypeBitwidth()`.
322struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
326 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
327 ConversionPatternRewriter &rewriter)
const override {
328 auto *typeConverter = this->
template getTypeConverter<SPIRVTypeConverter>();
329 Type indexType = typeConverter->getIndexType();
// The answer is fixed by the type converter's configuration, so it
// folds to a constant.
330 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
331 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
332 op, indexType, IntegerAttr::get(indexType, bitwidth));
363 ConvertIndexConstantBoolOpPattern,
364 ConvertIndexConstantOpPattern,
365 ConvertIndexCeilDivSPattern,
366 ConvertIndexCeilDivUPattern,
367 ConvertIndexFloorDivSPattern,
370 ConvertIndexCmpPattern,
372 >(typeConverter,
patterns.getContext());
380#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
381#include "mlir/Conversion/Passes.h.inc"
389struct ConvertIndexToSPIRVPass
393 void runOnOperation()
override {
394 Operation *op = getOperation();
396 std::unique_ptr<SPIRVConversionTarget>
target =
399 SPIRVConversionOptions
options;
400 options.use64bitIndex = this->use64bitIndex;
401 SPIRVTypeConverter typeConverter(targetAttr,
options);
405 target->addLegalOp<UnrealizedConversionCastOp>();
408 target->addLegalDialect<spirv::SPIRVDialect>();
410 target->addIllegalDialect<index::IndexDialect>();
static llvm::ManagedStatic< PassManagerOptions > options
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getType() const
Return the type of this value.
void populateIndexToSPIRVPatterns(const SPIRVTypeConverter &converter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.