// Body fragment of an im2col rewrite; the signature and several statements
// lie outside this view. NOTE(review): the filter is collapsed to
// {fh * fw * ic, oc}, so this is presumably the NHWC/HWCF overload - confirm.
// Visible flow: validate the op, collapse filter and output, build a column
// tensor from the input, combine it with the collapsed filter, and expand
// the result back to the original output type.
124 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
125 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
126 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
// Preconditions: pure tensor semantics, static filter shape, static input
// shape, and unit dilations. The call that consumes each message is not
// visible in this view.
128 if (!convOp.hasPureTensorSemantics())
130 convOp,
"expected op to have pure tensor semantics");
132 if (!filterType.hasStaticShape())
134 convOp,
"expected a static shape for the filter");
136 if (!inputType.hasStaticShape())
138 "expected a static shape for the input");
143 "expected all ones for dilations");
// Operands of the convolution: two inputs (image, filter) and one output.
146 Value input = convOp.getInputs()[0];
147 Value filter = convOp.getInputs()[1];
148 Value output = convOp.getOutputs()[0];
// The filter must be a ranked tensor so that its encoding can be carried
// over to the collapsed filter type below.
163 assert(isa<RankedTensorType>(filterType) &&
164 "expected filter type to be a ranked tensor");
165 auto tensorFilterType = cast<RankedTensorType>(filterType);
// Collapse the filter to 2-D: {fh * fw * ic, oc}, keeping the filter's
// element type and encoding.
169 auto reshapedFilterType =
170 RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType(),
171 tensorFilterType.getEncoding());
172 Value reshapedFilter = tensor::CollapseShapeOp::create(
173 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
// Collapse the output to 3-D: {n, oh * ow, oc}.
176 RankedTensorType reshapedOutputType =
177 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
178 Value reshapedOutput = tensor::CollapseShapeOp::create(
179 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
// Destination for the img2col data; it has the input's element type.
182 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
183 inputType.getElementType());
// One loop per dimension of the column tensor.
186 auto nloops = colTensorShape.size();
188 auto parallel = utils::IteratorType::parallel;
189 auto reduction = utils::IteratorType::reduction;
// The collapsed k dimension supplies (fh, fw, ic) in that order; the
// collapsed m dimension supplies (oh, ow).
199 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
200 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
201 i2cToOperExprs.
icIndex = kIndicesExprs[2];
202 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
203 i2cToOperExprs.
owIndex = mIndicesExprs[1];
207 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<
int64_t>()),
// img2col step: a linalg.generic from `input` into `colTensor` whose body
// yields the input element unchanged.
217 auto img2ColTensor = linalg::GenericOp::create(
218 rewriter, loc, colTensor.
getType(),
219 input, colTensor, img2colIndexingMaps,
222 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
// Four loop dimensions (b, m, n, k); the result map drops k and keeps
// (b, m, n).
230 bindDims(context, bDim, mDim, nDim, kDim);
233 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
235 parallel, reduction};
// Second linalg.generic: takes the column tensor and the collapsed filter
// (in that operand order) and yields `add`.
237 auto genericOp = linalg::GenericOp::create(
238 rewriter, loc, reshapedOutputType,
239 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
246 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
// Expand the collapsed result back to the original output type, reusing
// the reassociation that collapsed the output.
250 auto reshapedResult = tensor::ExpandShapeOp::create(
251 rewriter, loc, outputType,
result, outputReassocIndices);
// Hand back both the img2col op and the final reshape op.
255 return std::make_pair(img2ColTensor.getOperation(),
256 reshapedResult.getOperation());
261 linalg::DepthwiseConv2DNhwcHwcOp convOp) {
// Body fragment of the depthwise im2col rewrite; the first line of the
// signature and several statements lie outside this view. Visible flow:
// validate the op, transpose input/filter/output, build a column tensor
// from the transposed input, collapse everything for a batch matvec, then
// expand and transpose the result back.
262 auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
263 auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
264 auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
// Preconditions: pure tensor semantics, static filter shape, static input
// shape, and unit dilations. The call that consumes each message is not
// visible in this view.
266 if (!convOp.hasPureTensorSemantics())
268 convOp,
"expected op to have pure tensor semantics");
270 if (!filterType.hasStaticShape())
272 convOp,
"expected a static shape for the filter");
274 if (!inputType.hasStaticShape())
276 "expected a static shape for the input");
281 "expected all ones for dilations");
// Transpose helper (its header is not visible; it is invoked below as
// `transposeOperand`): creates an empty tensor of `targetShape` with the
// operand's element type and fills it through an all-parallel
// linalg.generic that yields each element unchanged.
286 auto operandTensorType = cast<RankedTensorType>(operand.
getType());
298 Value outputTensor = tensor::EmptyOp::create(
299 rewriter, loc, targetShape, operandTensorType.getElementType());
302 nloops, utils::IteratorType::parallel);
309 auto transposedOp = linalg::GenericOp::create(
310 rewriter, loc, outputTensor.getType(),
311 operand, outputTensor, indexingMaps,
314 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
317 return transposedOp.getResult(0);
// Operands of the convolution: two inputs (image, filter) and one output.
320 Value input = convOp.getInputs()[0];
321 Value filter = convOp.getInputs()[1];
322 Value output = convOp.getOutputs()[0];
// Permute the input with {0, 3, 1, 2} and the filter with {2, 0, 1}.
325 Value inputT = transposeOperand(input, {0, 3, 1, 2});
326 Value filterT = transposeOperand(filter, {2, 0, 1});
328 cast<RankedTensorType>(filterT.getType()).getShape();
// Extents come from the (untransposed) output shape and the transposed
// filter shape.
// NOTE(review): these are stored in `int` and later multiplied as `int`
// (n * c, oh * ow, fh * fw); if the shape accessors return 64-bit extents
// this narrows - confirm, and consider int64_t.
331 int n = outputShape[0];
332 int oh = outputShape[1];
333 int ow = outputShape[2];
334 int c = outputShape[3];
335 int fh = filterTShape[1];
336 int fw = filterTShape[2];
// The output is permuted the same way as the input.
339 Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});
// Six loop dimensions for the column tensor; the two stride values are
// read from the op's strides attribute.
341 AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
345 convOp.getStrides().getValues<
int64_t>()[0]);
347 convOp.getStrides().getValues<
int64_t>()[1]);
350 owDim * swSym + kwDim};
352 auto nloops = colTensorShape.size();
355 nloops, utils::IteratorType::parallel);
// Destination for the img2col data; it has the input's element type.
361 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
362 inputType.getElementType());
// img2col step: a linalg.generic from the transposed input into
// `colTensor` whose body yields the input element unchanged.
364 auto img2ColTensor = linalg::GenericOp::create(
365 rewriter, loc, colTensor.
getType(),
366 inputT, colTensor, indexingMaps,
369 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
373 {0, 1}, {2, 3}, {4, 5}};
// Collapsed types for the batch matvec: column tensor
// {n * c, oh * ow, fh * fw}, filter {c, fh * fw}, output {n * c, oh * ow}.
378 auto reshapedImg2ColTensorType = RankedTensorType::get(
379 {n * c, oh * ow, fh * fw}, inputType.getElementType());
380 auto reshapedFilterTensorType =
381 RankedTensorType::get({c, fh * fw}, filterType.getElementType());
382 auto reshapedOutputTensorType =
383 RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());
385 Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create(
386 rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
387 img2ColTensorReassocIndices);
388 Value reshapedFilterTensor =
389 tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType,
390 filterT, filterReassociationIndice);
391 Value reshapedoutputTensor = tensor::CollapseShapeOp::create(
392 rewriter, loc, reshapedOutputTensorType, transposedOutputTensor,
393 outputReassociationIndice);
// Batch matvec over the collapsed column tensor and collapsed filter.
395 auto batchMatVecResult = linalg::BatchMatvecOp::create(
397 ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
// Expand back to the transposed output type, then permute with
// {0, 2, 3, 1}, the inverse of the {0, 3, 1, 2} applied above.
403 auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create(
404 rewriter, loc, transposedOutputTensor.
getType(),
405 batchMatVecResult.getResult(0), batchMatVecReassociationIndice);
407 Value transposedResult =
408 transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});
// The first element of the returned pair is the img2col op; the second
// element is outside this view.
411 return std::make_pair(img2ColTensor.getOperation(),
// Body fragment of an im2col rewrite; the signature and several statements
// lie outside this view. NOTE(review): the filter is collapsed to
// {oc, ic * fh * fw} and the output to {n, oc, oh * ow}, so this is
// presumably the NCHW/FCHW overload - confirm.
// Visible flow: validate the op, collapse filter and output, build a column
// tensor from the input, combine the collapsed filter with it, and expand
// the result back to the original output type.
417 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
418 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
419 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
// Preconditions: pure tensor semantics, static filter shape, static input
// shape, and unit dilations. The call that consumes each message is not
// visible in this view.
421 if (!convOp.hasPureTensorSemantics())
423 convOp,
"expected op to have pure tensor semantics");
425 if (!filterType.hasStaticShape())
427 convOp,
"expected a static shape for the filter");
429 if (!inputType.hasStaticShape())
431 "expected a static shape for the input");
436 "expected all ones for dilations");
// Operands of the convolution: two inputs (image, filter) and one output.
438 Value input = convOp.getInputs()[0];
439 Value filter = convOp.getInputs()[1];
440 Value output = convOp.getOutputs()[0];
442 auto filterShape = filterType.getShape();
443 auto outputShape = outputType.getShape();
453 auto loc = convOp.getLoc();
// The filter must be a ranked tensor so that its encoding can be carried
// over to the collapsed filter type below.
456 assert(isa<RankedTensorType>(filterType) &&
457 "expected filter type to be a ranked tensor");
458 auto tensorFilterType = cast<RankedTensorType>(filterType);
// Collapse the filter to 2-D: {oc, ic * fh * fw}. The collapsed type must
// keep the *filter's* element type (not the input's): a collapse only
// changes the shape, so a differing element type would make the op invalid
// whenever input and filter element types differ. This also matches the
// sibling rewrites in this file.
461 auto reshapedFilterType =
462 RankedTensorType::get({oc, ic * fh * fw}, filterType.getElementType(),
463 tensorFilterType.getEncoding());
464 Value reshapedFilter = tensor::CollapseShapeOp::create(
465 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
// Collapse the output to 3-D: {n, oc, oh * ow}.
468 auto reshapedOutputType =
469 RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
470 Value reshapedOutput = tensor::CollapseShapeOp::create(
471 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
// Destination for the img2col data; it has the input's element type.
475 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
476 inputType.getElementType());
// One loop per dimension of the column tensor.
478 auto nloops = colTensorShape.size();
480 auto parallel = utils::IteratorType::parallel;
481 auto reduction = utils::IteratorType::reduction;
// The collapsed k dimension supplies (ic, fh, fw) in that order; the
// collapsed m dimension supplies (oh, ow).
492 i2cToOperExprs.
icIndex = kIndicesExprs[0];
493 i2cToOperExprs.
fhIndex = kIndicesExprs[1];
494 i2cToOperExprs.
fwIndex = kIndicesExprs[2];
495 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
496 i2cToOperExprs.
owIndex = mIndicesExprs[1];
498 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<
int64_t>()),
// img2col step: a linalg.generic from `input` into `colTensor` whose body
// yields the input element unchanged.
508 auto img2ColTensor = linalg::GenericOp::create(
509 rewriter, loc, colTensor.
getType(),
510 input, colTensor, img2colIndexingMaps,
513 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
// Four loop dimensions (b, m, n, k); the result map drops k and keeps
// (b, m, n).
521 bindDims(context, bDim, mDim, nDim, kDim);
524 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
526 parallel, reduction};
// Second linalg.generic: here the collapsed filter is the first operand
// and the column tensor the second; the body yields `add`.
527 auto genericOp = linalg::GenericOp::create(
528 rewriter, loc, reshapedOutputType,
529 ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
536 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
// Expand the collapsed result back to the original output type, reusing
// the reassociation that collapsed the output.
540 auto reshapedResult = tensor::ExpandShapeOp::create(
541 rewriter, loc, outputType,
result, outputReassocIndices);
// Hand back both the img2col op and the final reshape op.
545 return std::make_pair(img2ColTensor.getOperation(),
546 reshapedResult.getOperation());
// Body fragment of an im2col rewrite; the signature and several statements
// lie outside this view. NOTE(review): the filter is collapsed to
// {oc, fh * fw * ic} and the output to {n, oh * ow, oc}, so this is
// presumably the NHWC/FHWC overload - confirm.
// Visible flow: validate the op, collapse filter and output, build a column
// tensor from the input, combine it with the collapsed filter, and expand
// the result back to the original output type.
551 auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
552 auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
553 auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
// Preconditions: pure tensor semantics, static filter shape, static input
// shape, and unit dilations. The call that consumes each message is not
// visible in this view.
555 if (!convOp.hasPureTensorSemantics())
557 convOp,
"expected op to have pure tensor semantics");
559 if (!filterType.hasStaticShape())
561 convOp,
"expected a static shape for the filter");
563 if (!inputType.hasStaticShape())
565 "expected a static shape for the input");
570 "expected all ones for dilations");
// Operands of the convolution: two inputs (image, filter) and one output.
573 Value input = convOp.getInputs()[0];
574 Value filter = convOp.getInputs()[1];
575 Value output = convOp.getOutputs()[0];
// The filter must be a ranked tensor so that its encoding can be carried
// over to the collapsed filter type below.
590 assert(isa<RankedTensorType>(filterType) &&
591 "expected filter type to be a ranked tensor");
592 auto tensorFilterType = cast<RankedTensorType>(filterType);
// Collapse the filter to 2-D: {oc, fh * fw * ic}, keeping the filter's
// element type and encoding.
597 auto reshapedFilterType =
598 RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType(),
599 tensorFilterType.getEncoding());
600 Value reshapedFilter = tensor::CollapseShapeOp::create(
601 rewriter, loc, reshapedFilterType, filter, filterReassocIndices);
// Collapse the output to 3-D: {n, oh * ow, oc}.
604 RankedTensorType reshapedOutputType =
605 RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
606 Value reshapedOutput = tensor::CollapseShapeOp::create(
607 rewriter, loc, reshapedOutputType, output, outputReassocIndices);
// Destination for the img2col data; it has the input's element type.
611 Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape,
612 inputType.getElementType());
// One loop per dimension of the column tensor.
615 auto nloops = colTensorShape.size();
617 auto parallel = utils::IteratorType::parallel;
618 auto reduction = utils::IteratorType::reduction;
// The collapsed k dimension supplies (fh, fw, ic) in that order; the
// collapsed m dimension supplies (oh, ow).
628 i2cToOperExprs.
fhIndex = kIndicesExprs[0];
629 i2cToOperExprs.
fwIndex = kIndicesExprs[1];
630 i2cToOperExprs.
icIndex = kIndicesExprs[2];
631 i2cToOperExprs.
ohIndex = mIndicesExprs[0];
632 i2cToOperExprs.
owIndex = mIndicesExprs[1];
636 i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<
int64_t>()),
// img2col step: a linalg.generic from `input` into `colTensor` whose body
// yields the input element unchanged.
645 auto img2ColTensor = linalg::GenericOp::create(
646 rewriter, loc, colTensor.
getType(),
647 input, colTensor, img2colIndexingMaps,
650 linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]);
// Four loop dimensions (b, m, n, k); the result map drops k and keeps
// (b, m, n).
657 bindDims(context, bDim, mDim, nDim, kDim);
660 auto resultMap =
AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
662 parallel, reduction};
// Second linalg.generic: takes the column tensor and the collapsed filter
// (in that operand order) and yields `add`.
664 auto genericOp = linalg::GenericOp::create(
665 rewriter, loc, reshapedOutputType,
666 ValueRange{img2ColTensor.getResult(0), reshapedFilter},
673 linalg::YieldOp::create(nestedBuilder, nestedLoc,
add);
// Expand the collapsed result back to the original output type, reusing
// the reassociation that collapsed the output.
677 auto reshapedResult = tensor::ExpandShapeOp::create(
678 rewriter, loc, outputType,
result, outputReassocIndices);
// Hand back both the img2col op and the final reshape op.
682 return std::make_pair(img2ColTensor.getOperation(),
683 reshapedResult.getOperation());
MLIRContext is the top-level object for a collection of MLIR operations.